| {_treat_html(i)} | " html_code += f"
|---|
| {_treat_html(i)} | " html_code += "
|`)?(\w*)\s*=\s*", cell['source'])
if match is not None: doc_fns[match.group(1)] = i
return doc_fns
def link_markdown_cells(cells, modules):
    "Run `link_docstring` over the source of every markdown cell in `cells`, in place."
    for nb_cell in cells:
        if nb_cell['cell_type'] != 'markdown': continue
        nb_cell['source'] = link_docstring(modules, nb_cell['source'])
def get_insert_idx(pos_dict, name):
    "Return the position stored for the first key of `pos_dict` that does not sort before `name` (case-insensitive), or -1."
    # Keys are scanned in the dict's own order; the first one that is >= `name` wins
    for key in pos_dict.keys():
        if str.lower(key) >= str.lower(name): return pos_dict[key]
    return -1
def update_pos(pos_dict, start_key, nbr=2):
    "Add `nbr` to every position of `pos_dict` whose key sorts at or after `start_key` (case-insensitive); returns `pos_dict`."
    shifted_keys = [k for k in pos_dict if str.lower(k) >= str.lower(start_key)]
    for k in shifted_keys: pos_dict[k] += nbr
    return pos_dict
def insert_cells(cells, pos_dict, ft_name, append=False):
    "Add a doc cell and an empty cell for `ft_name` to `cells`, either appended or at its sorted position; returns `(cells, pos_dict)`."
    idx = get_insert_idx(pos_dict, ft_name)
    if not append and idx != -1:
        # Splice both cells in at `idx`, then shift the recorded positions that follow
        cells[idx:idx] = [get_doc_cell(ft_name), get_empty_cell()]
        pos_dict = update_pos(pos_dict, ft_name, 2)
    else:
        cells.extend([get_doc_cell(ft_name), get_empty_cell()])
    return cells, pos_dict
def get_doc_path(mod, dest_path):
    "Path of the `.ipynb` file named after `mod` (name passed through `strip_fastai`) inside `dest_path`."
    nb_name = f'{strip_fastai(mod.__name__)}.ipynb'
    return os.path.join(dest_path, nb_name)
def generate_missing_metadata(dest_file):
    "Append an `update_nb_metadata(...)` code cell for `dest_file` to the sibling `jekyll_metadata.ipynb`, unless one is already there."
    fn = Path(dest_file)
    meta_fn = fn.parent/'jekyll_metadata.ipynb'
    if not (fn.exists() and meta_fn.exists()):
        print('Could not find notebooks:', fn, meta_fn)
        return
    metadata_nb = read_nb(meta_fn)
    if has_metadata_cell(metadata_nb['cells'], fn.name): return
    # Replay the notebook's current jekyll metadata as keyword arguments of the generated call
    jekyll_md = read_nb(fn)['metadata'].get('jekyll', {})
    fmt_params = ''.join(f',\n {k}={stringify(v)}' for k,v in jekyll_md.items())
    call_src = f"update_nb_metadata('{Path(fn).name}'{fmt_params})"
    metadata_nb['cells'].append(get_code_cell(call_src, hidden=False))
    write_nb(metadata_nb, meta_fn)
def update_nb_metadata(nb_path=None, title=None, summary=None, keywords='fastai', overwrite=True, **kwargs):
    "Store the non-`None` values among `title`, `summary`, `keywords` and `kwargs` as the jekyll metadata of the notebook at `nb_path`."
    # NOTE(review): `overwrite` is accepted but not read anywhere in this function
    nb = read_nb(nb_path)
    fields = dict(title=title, summary=summary, keywords=keywords)
    fields.update(kwargs)
    jekyll_md = {key:val for key,val in fields.items() if val is not None}
    if len(jekyll_md) == 0: return  # nothing to record: the notebook is not rewritten
    nb['metadata']['jekyll'] = jekyll_md
    write_nb(nb, nb_path)
    NotebookNotary().sign(nb)
def has_metadata_cell(cells, fn):
    "Return the first cell of `cells` whose source calls `update_nb_metadata` for the notebook named `fn`, or `None`."
    # `fn` is escaped so it is matched literally: a name such as `a.ipynb` contains `.`,
    # which would otherwise match any character
    pattern = re.compile(r"update_nb_metadata\('" + re.escape(fn) + "'")
    for c in cells:
        if pattern.search(c['source']): return c
    return None
def stringify(s):
    "Return a Python string literal for `s` when it is a `str`; any other value is returned unchanged."
    # `repr` escapes embedded quotes and newlines, so the text stays valid when pasted into generated code
    return repr(s) if isinstance(s, str) else s
IMPORT_RE = re.compile(r"from (fastai[\.\w_]*)")
def get_imported_modules(cells, nb_module_name=''):
    "Import and return, in this order, the top-level modules, the `from fastai...` imports found in code cells, and `nb_module_name` with its parents."
    module_names = get_top_level_modules()
    nb_imports = []
    for cell in cells:
        for match in IMPORT_RE.finditer(cell['source']):
            if cell['cell_type'] == 'code': nb_imports.append(match.group(1))
    parts = nb_module_name.split('.')
    # a.b.c -> [a, a.b, a.b.c]
    parent_modules = ['.'.join(parts[:depth+1]) for depth in range_of(parts)]
    loaded = (import_mod(m, ignore_errors=True) for m in module_names + nb_imports + parent_modules)
    # modules that could not be imported come back as None and are dropped
    return [m for m in loaded if m is not None]
def get_top_level_modules(num_levels=1):
    "Module names found under the `fastai` package directory with at most `num_levels` dots, those with more dots first."
    mod_dir = Path(import_mod('fastai').__file__).parent
    shallow = [n for n in get_module_names(mod_dir) if n.count('.') <= num_levels]
    shallow.sort(key=lambda s: s.count('.'), reverse=True)  # submodules first (sorted by periods)
    return shallow
NEW_FT_HEADER = '## New Methods - Please document or move to the undocumented section'
UNDOC_HEADER = '## Undocumented Methods - Methods moved below this line will intentionally be hidden'
def parse_sections(cells):
    "Split `cells` into (documented, undocumented, new) lists, switching list at each markdown cell that starts with a section header."
    documented, undocumented, fresh = [], [], []
    target = documented
    for cell in cells:
        is_md = cell['cell_type'] == 'markdown'
        if is_md and re.match(UNDOC_HEADER, cell['source']): target = undocumented
        if is_md and re.match(NEW_FT_HEADER, cell['source']): target = fresh
        target.append(cell)  # the header cell itself opens its section
    # A section that never appeared is returned holding just its header cell
    if not undocumented: undocumented = [get_md_cell(UNDOC_HEADER)]
    if not fresh: fresh = [get_md_cell(NEW_FT_HEADER)]
    return documented, undocumented, fresh
def remove_undoc_cells(cells):
    "Return only the first group produced by `parse_sections`, i.e. the cells before any section header."
    return parse_sections(cells)[0]
# currently code vbox sub-cells mainly
def remove_code_cell_jupyter_widget_state_elem(cells):
    "Drop from every code cell the outputs whose `data` holds a jupyter widget-view payload; returns `cells`."
    widget_mime = 'application/vnd.jupyter.widget-view+json'
    for c in cells:
        if c['cell_type'] != 'code' or 'outputs' not in c: continue
        # Item access instead of `.data`, so plain dict outputs are handled as well
        c['outputs'] = [o for o in c['outputs'] if not ('data' in o and widget_mime in o['data'])]
    return cells
def update_module_page(mod, dest_path='.'):
    "Update the documentation notebook of a given module and return its path."
    doc_path = get_doc_path(mod, dest_path)
    strip_name = strip_fastai(mod.__name__)
    nb = read_nb(doc_path)
    cells = nb['cells']

    link_markdown_cells(cells, get_imported_modules(cells, mod.__name__))

    # Refresh (or append) the markdown cell of every exported global variable
    type_dict = read_nb_types(cells)
    gvar_map = get_global_vars(mod)
    for name in get_exports(mod):
        if name not in gvar_map: continue
        code = gvar_map[name]
        if name in type_dict: cells[type_dict[name]] = get_md_cell(code)
        else: cells.append(get_md_cell(code))

    pos_dict = read_nb_content(cells, strip_name)
    ft_names = get_ft_names(mod, include_inner=True)
    # De-duplicate while keeping the order of `ft_names`, so the cells generated below
    # come out in the same order on every run (a set difference has no stable order)
    new_fts = [name for name in dict.fromkeys(ft_names) if name not in pos_dict]
    if new_fts: print(f'Found new functions for {mod}. Please document:\n{new_fts}')
    existing, undoc_cells, new_cells = parse_sections(cells)
    for ft_name in new_fts: new_cells.extend([get_doc_cell(ft_name), get_empty_cell()])
    # `new_cells` always holds at least its header; only rewrite the layout when it holds more
    if len(new_cells) > 1: nb['cells'] = existing + undoc_cells + new_cells

    write_nb(nb, doc_path)
    return doc_path
def link_nb(nb_path):
    "Link the markdown cells of the notebook at `nb_path`, write it back, then sign the written file."
    nb = read_nb(nb_path)
    nb_cells = nb['cells']
    modules = get_imported_modules(nb_cells, Path(nb_path).stem)
    link_markdown_cells(nb_cells, modules)
    write_nb(nb, nb_path)
    NotebookNotary().sign(read_nb(nb_path))
def get_module_from_notebook(doc_path):
    "Module name for a notebook path: `fastai.` followed by the file stem of `doc_path`."
    stem = Path(doc_path).stem
    return 'fastai.' + stem
def check_nbconvert_version():
    "Assert that the installed nbconvert reports a `version_info` of at least (5,4,0)."
    import nbconvert
    too_old = nbconvert.version_info < (5,4,0)
    assert not too_old, "Please update nbconvert to >=5.4 for consistent .html output"
def update_notebooks(source_path, dest_path=None, update_html=True, document_new_fns=False,
                     update_nb_links=True, html_path=None, force=False):
    "`source_path` can be a directory or a file. Assume all modules reside in the fastai directory."
    from .convert2html import convert_nb
    source_path = Path(source_path)

    if source_path.is_file():
        dest_path = source_path.parent if dest_path is None else Path(dest_path)
        html_path = dest_path/'..'/'docs' if html_path is None else Path(html_path)
        doc_path = source_path
        assert source_path.suffix == '.ipynb', 'Must update from notebook or module'
        if document_new_fns:
            mod = import_mod(get_module_from_notebook(source_path))
            if not mod: print('Could not find module for path:', source_path)
            elif mod.__file__.endswith('__init__.py'): pass  # package `__init__` modules are left alone
            else: update_module_page(mod, dest_path)
        generate_missing_metadata(doc_path)
        if update_nb_links:
            print(f'Updating notebook {doc_path}. Please wait...')
            link_nb(doc_path)
            execute_nb(doc_path, {'metadata': {'path': doc_path.parent}}, show_doc_only=True)
        if update_html:
            check_nbconvert_version()
            html_fn = html_path/doc_path.with_suffix('.html').name
            if not force and html_fn.is_file():
                # Skip the conversion when the existing html is newer than the notebook
                in_mod = os.path.getmtime(doc_path)
                out_mod = os.path.getmtime(html_fn)
                if in_mod < out_mod: return
            convert_nb(doc_path, html_path)

    elif (source_path.name.startswith('fastai.')):
        # Do module update
        assert dest_path is not None, 'To update a module, you must specify a destination folder for where notebook resides'
        mod = import_mod(source_path.name)
        if not mod: return print('Could not find module for:', source_path)
        doc_path = Path(dest_path)/(strip_fastai(mod.__name__)+'.ipynb')
        if not doc_path.exists():
            print('Notebook does not exist. Creating:', doc_path)
            create_module_page(mod, dest_path)
        # `force` is forwarded so that it also applies to the notebook resolved from the module
        update_notebooks(doc_path, dest_path=dest_path, update_html=update_html, document_new_fns=document_new_fns,
                         update_nb_links=update_nb_links, html_path=html_path, force=force)
    elif source_path.is_dir():
        for f in sorted(Path(source_path).glob('*.ipynb')):
            # `force` is forwarded so that it applies to every notebook of the directory
            update_notebooks(f, dest_path=dest_path, update_html=update_html, document_new_fns=document_new_fns,
                             update_nb_links=update_nb_links, html_path=html_path, force=force)
    else: print('Could not resolve source file:', source_path)
================================================
FILE: fastai/gen_doc/hide.tpl
================================================
{%- extends 'basic.tpl' -%}
{#- The input group is rendered through the parent block unless the cell or the notebook sets `hide_input` in its metadata. -#}
{% block input_group -%}
{%- if cell.metadata.hide_input or nb.metadata.hide_input -%}
{%- else -%}
{{ super() }}
{%- endif -%}
{% endblock input_group %}
{#- The output group is rendered through the parent block unless the cell sets `hide_output` in its metadata. -#}
{% block output_group -%}
{%- if cell.metadata.hide_output -%}
{%- else -%}
{{ super() }}
{%- endif -%}
{% endblock output_group %}
{#- The output prompt uses the same `hide_input` test as the input group. -#}
{% block output_area_prompt %}
{%- if cell.metadata.hide_input or nb.metadata.hide_input -%}
{%- else -%}
{{ super() }}
{%- endif -%}
{% endblock output_area_prompt %}
================================================
FILE: fastai/gen_doc/jekyll.tpl
================================================
{%- extends 'hide.tpl' -%}{% block body %}---
{% if resources.toc != "" and resources.toc != nil %}toc: {{resources.toc}}{% endif %}
{% if resources.title != "" and resources.title != nil %}title: {{resources.title}}{% endif %}
keywords: {{resources.keywords}}
sidebar: home_sidebar
{% if resources.tags != "" and resources.tags != nil %}tags: {{resources.tags}}{% endif %}
{% if resources.summary != "" and resources.summary != nil %}summary: "{{resources.summary}}"{% endif %}
---
{% include 'autogen.tpl' %}
{{ super() }}
{%- endblock body %}
{#- The body block writes a `---`-delimited header (toc, title, keywords, sidebar, tags, summary read from `resources`), then includes `autogen.tpl` and the parent body. -#}
{#- NOTE(review): `nil` does not look like a Jinja literal; presumably it resolves as an undefined variable here - confirm. -#}
================================================
FILE: fastai/gen_doc/nbdoc.py
================================================
"`gen_doc.nbdoc` generates notebook documentation from module functions and links to correct places"
import inspect,importlib,enum,os,re,nbconvert
from IPython.core.display import display, Markdown, HTML
from nbconvert import HTMLExporter
from IPython.core import page
from IPython import get_ipython
from typing import Dict, Any, AnyStr, List, Sequence, TypeVar, Tuple, Optional, Union
from .docstrings import *
from .core import *
from ..torch_core import *
from .nbtest import get_pytest_html
from ..utils.ipython import IS_IN_COLAB
# Names exported by `from ... import *`
__all__ = ['get_fn_link', 'link_docstring', 'show_doc', 'get_ft_names', 'md2html',
'get_exports', 'show_video', 'show_video_from_youtube', 'import_mod', 'get_source_link',
'is_enum', 'jekyll_note', 'jekyll_warn', 'jekyll_important', 'doc']
MODULE_NAME = 'fastai'  # module-name prefix tested by `is_fastai_class`
SOURCE_URL = 'https://github.com/fastai/fastai/blob/master/'  # prefix of the links built in `get_source_link`
PYTORCH_DOCS = 'https://pytorch.org/docs/stable/'  # prefix of the links built in `get_pytorch_link`
FASTAI_DOCS = 'https://docs.fast.ai'  # base used by `get_fn_link` when `use_relative_links` is False
use_relative_links = True  # `doc` switches this off while it builds its preview
# Entries of `fastai_types` whose key is defined in the `typing` module
_typing_names = {t:n for t,n in fastai_types.items() if t.__module__=='typing'}
# Escaped asterisk prefixes for `*args` / `**kwargs` parameters. Raw strings keep the backslashes
# without relying on `\*` being an unrecognised (and warned-about) escape sequence.
arg_prefixes = {inspect._VAR_POSITIONAL: r'\*', inspect._VAR_KEYWORD: r'\*\*'}
def is_enum(cls):
    "Whether `cls` is `enum.Enum` or `enum.EnumMeta` itself."
    if cls == enum.Enum: return True
    return cls == enum.EnumMeta
def link_type(arg_type, arg_name=None, include_bt:bool=True):
    "Return `arg_name` (default: `fn_name(arg_type)`) as a markdown link to the pytorch or fastai docs, or unlinked when neither applies."
    label = arg_name if arg_name else fn_name(arg_type)
    if include_bt: label = code_esc(label)
    # torch objects are linked to the pytorch docs, except when the label mentions `Tensor`
    if belongs_to_module(arg_type, 'torch') and 'Tensor' not in label: target = get_pytorch_link(arg_type)
    elif is_fastai_class(arg_type): target = get_fn_link(arg_type)
    else: return label
    return f'[{label}]({target})'
def is_fastai_class(t):
    "Whether `t` belongs to the module named by `MODULE_NAME`."
    return belongs_to_module(t, module_name=MODULE_NAME)
def belongs_to_module(t, module_name):
    "Whether the name of the module that defines `t` starts with `module_name`."
    # Unwrap bound methods down to the underlying function
    while hasattr(t, '__func__'): t = t.__func__
    owner = inspect.getmodule(t)
    if not owner: return False
    return owner.__name__.startswith(module_name)
def code_esc(s):
    "Wrap `s` in backticks."
    return '`{}`'.format(s)
def type_repr(t):
    "Linked markdown representation of type `t`, recursing into the `__args__` of generic types."
    if t in _typing_names: return link_type(t, _typing_names[t])
    if isinstance(t, partial): return partial_repr(t)
    if hasattr(t, '__forward_arg__'): return link_type(t.__forward_arg__)
    elif getattr(t, '__args__', None):
        args = t.__args__
        # Raw f-strings below: the brackets are backslash-escaped in the output, and `\[` is not a valid string escape
        if len(args)==2 and args[1] == type(None):
            # two arguments, the second being NoneType, are shown as Optional[...]
            return rf'`Optional`\[{type_repr(args[0])}\]'
        reprs = ', '.join([type_repr(o) for o in args])
        return rf'{link_type(t)}\[{reprs}\]'
    else: return link_type(t)
def partial_repr(t):
    "Representation of partial `t`: its function, positional args and `key=value` keywords, each passed through `link_type`."
    kw_parts = tuple(f'{k}={v}' for k,v in t.keywords.items())
    pieces = (t.func,) + t.args + kw_parts
    return 'partial(' + ', '.join(link_type(p) for p in pieces) + ')'
def anno_repr(a):
    "Representation of annotation `a`; delegates to `type_repr`."
    return type_repr(a)
def format_param(p):
    "Formats function param to `param1:Type=val`. Font weights: param1=bold, val=bold+italic"
    star_prefix = arg_prefixes.get(p.kind, '')  # asterisk prefix for *args and **kwargs
    pieces = [f"**{star_prefix}{code_esc(p.name)}**"]
    if hasattr(p, 'annotation') and p.annotation != p.empty: pieces.append(f':{anno_repr(p.annotation)}')
    if p.default != p.empty:
        # a default with a `func` attribute is shown through that function, and named objects by their `__name__`
        shown = getattr(p.default, 'func', p.default)
        shown = getattr(shown, '__name__', shown)
        pieces.append(f'=***`{repr(shown)}`***')
    return ''.join(pieces)
def format_ft_def(func, full_name:str=None)->str:
    "Return a `(title, signature)` pair of strings describing `func`, with linked parameter and return types."
    sig = inspect.signature(func)
    name = f'{full_name or func.__name__}'
    shown_params = []
    for p_name,param in sig.parameters.items():
        if p_name in ('self','cls'): continue
        shown_params.append(format_param(param))
    arg_str = '(' + ', '.join(shown_params) + ')'
    ret = sig.return_annotation
    if ret and (ret != sig.empty): arg_str += f" → {anno_repr(ret)}"
    if is_fastai_class(type(func)): arg_str += f" :: {link_type(type(func))}"
    if inspect.isclass(func): f_name = f"class {name}"
    else: f_name = name
    return f'{f_name}',f'{name}{arg_str}'
def get_enum_doc(elt, full_name:str)->str:
    "Return the backticked `full_name` and an `Enum = [...]` listing of the member names of `elt`."
    member_names = ', '.join(elt.__members__)
    title = f'{code_esc(full_name)}'
    return title, f'Enum = [{member_names}]'
def get_cls_doc(elt, full_name:str)->str:
    "Class definition as produced by `format_ft_def`, followed by a link to the parent class unless it is `object`."
    class_tree = inspect.getclasstree([elt])
    parent_class = class_tree[-1][0][1][0]
    name,args = format_ft_def(elt, full_name)
    if parent_class != object:
        parent_link = link_type(parent_class, include_bt=True)
        args += f' :: {parent_link}'
    return name,args
def show_doc(elt, doc_string:bool=True, full_name:str=None, arg_comments:dict=None, title_level=None, alt_doc_string:str='',
             ignore_warn:bool=False, markdown=True, show_tests=True):
    "Build the documentation of `elt` (class, Callable or enum); display it as Markdown, or return the text when `markdown` is False."
    arg_comments = ifnone(arg_comments, {})
    anchor_id = get_anchor(elt)  # taken before `__func__` is unwrapped
    elt = getattr(elt, '__func__', elt)
    full_name = full_name or fn_name(elt)
    elt_is_class = inspect.isclass(elt)
    if elt_is_class and is_enum(elt.__class__): name,args = get_enum_doc(elt, full_name)
    elif elt_is_class: name,args = get_cls_doc(elt, full_name)
    elif isinstance(elt, Callable): name,args = format_ft_def(elt, full_name)
    else: raise Exception(f'doc definition not supported for {full_name}')
    source_link = get_function_source(elt) if is_fastai_class(elt) else ""
    if show_tests: test_link, test_modal = get_pytest_html(elt, anchor_id=anchor_id)
    else: test_link, test_modal = '', ''
    # NOTE(review): `title_level` is resolved here but does not appear in the text built below
    title_level = ifnone(title_level, 2 if elt_is_class else 4)
    pieces = [f'{name}{source_link}{test_link} ', f'\n\n> {args}\n\n', f'{test_modal}']
    if doc_string and (inspect.getdoc(elt) or arg_comments):
        pieces.append(format_docstring(elt, arg_comments, alt_doc_string, ignore_warn) + ' ')
    doc = ''.join(pieces)
    if markdown: display(Markdown(doc))
    else: return doc
def md2html(md):
    "Convert markdown `md` to HTML with nbconvert's `HTMLExporter`."
    # Compare version tuples: string comparison is lexicographic, so '5.10.0' would sort before '5.5.0'
    if nbconvert.version_info < (5,5,0): return HTMLExporter().markdown2html(md)
    # from 5.5.0 on, `markdown2html` is called with a context as its first argument
    return HTMLExporter().markdown2html(defaultdict(lambda: defaultdict(dict)), md)
def doc(elt):
    "Show `show_doc` info in preview window along with link to full docs."
    global use_relative_links
    use_relative_links = False
    try:
        elt = getattr(elt, '__func__', elt)
        md = show_doc(elt, markdown=False)
        if is_fastai_class(elt):
            md += f'\n\nShow in docs'
        output = md2html(md)
    finally:
        # Restore the module-wide flag even when building the preview fails
        use_relative_links = True
    if IS_IN_COLAB: get_ipython().run_cell_magic(u'html', u'', output)
    else:
        # Fall back to inline markdown when the pager cannot show the html
        try: page.page({'text/html': output})
        except Exception: display(Markdown(md))
def format_docstring(elt, arg_comments:dict={}, alt_doc_string:str='', ignore_warn:bool=False)->str:
    "Build the markdown text from the parsed docstring of `elt`, `arg_comments` (which override it) and `alt_doc_string` (which replaces the description)."
    doc = parse_docstring(inspect.getdoc(elt))
    pieces = []
    description = alt_doc_string or f"{doc['short_description']} {doc['long_description']}"
    if description: pieces.append(f'\n\n{link_docstring(inspect.getmodule(elt), description)}')
    resolved_comments = {**doc.get('comments', {}), **arg_comments}  # arg_comments takes priority
    if is_enum(elt.__class__): known_args = elt.__members__.keys()
    else: known_args = inspect.getfullargspec(elt).args
    if resolved_comments: pieces.append('\n')
    for arg_name,comment in resolved_comments.items():
        pieces.append(f'\n- *{arg_name}*: {comment}')
        # a comment for a name that is not an argument is reported but still kept
        if arg_name not in known_args and not ignore_warn: warn(f'Doc arg mismatch: {arg_name}')
    return_comment = arg_comments.get('return') or doc.get('return')
    if return_comment: pieces.append(f'\n\n*return*: {return_comment}')
    return ''.join(pieces)
_modvars = {}
def replace_link(m):
    "Substitution callback: return the keyword captured by match `m` as a link, or the matched text unchanged if it cannot be resolved."
    keyword = m.group(1) or m.group(2)
    target = find_elt(_modvars, keyword)
    return m.group() if target is None else link_type(target, arg_name=keyword)
# Finds all backticked keywords, whether or not they are already wrapped in a markdown link
# (raw string, so the regex escapes are not read as string escapes)
BT_REGEX = re.compile(r"\[`([^`]*)`\](?:\([^)]*\))|`([^`]*)`") # matches [`key`](link) or `key`
def link_docstring(modules, docstring:str, overwrite:bool=False)->str:
    "Search `docstring` for backticks and attempt to link those functions to respective documentation."
    # NOTE(review): `overwrite` is accepted but not read in this function
    mods = listify(modules)
    for mod in mods: _modvars.update(mod.__dict__) # concat all module definitions
    return re.sub(BT_REGEX, replace_link, docstring)
def find_elt(modvars, keyword, match_last=False):
    "Resolve a dotted keyword such as `Learner.lr_find` in `modvars`, descending through `__dict__` one component at a time; `None` if not found."
    keyword = strip_fastai(keyword)
    if keyword in modvars: return modvars[keyword]
    head, _, rest = keyword.partition('.')
    owner = modvars.get(head)
    if not hasattr(owner, '__dict__'): return None
    return find_elt(owner.__dict__, rest, match_last=match_last)
def import_mod(mod_name:str, ignore_errors=False):
    "Return module from `mod_name`, or `None` when it cannot be imported (a message is printed unless `ignore_errors`)."
    pkg, dot, rest = mod_name.partition('.')
    try:
        # Dotted names are imported relative to their first component
        if dot: return importlib.import_module('.' + rest, pkg)
        return importlib.import_module(mod_name)
    except Exception:
        # `Exception` rather than a bare except, so KeyboardInterrupt / SystemExit still propagate
        if not ignore_errors: print(f"Module {mod_name} doesn't exist.")
        return None
def show_doc_from_name(mod_name, ft_name:str, doc_string:bool=True, arg_comments:dict={}, alt_doc_string:str=''):
    "Show documentation for `ft_name`, see `show_doc`. Raises `AssertionError` if a component of `ft_name` cannot be found."
    mod = import_mod(mod_name)
    splits = str.split(ft_name, '.')
    # The explanation is carried by the exception itself (it used to be printed, leaving the error without a message)
    if not hasattr(mod, splits[0]):
        raise AssertionError(f"Module {mod_name} doesn't have a function named {splits[0]}.")
    elt = getattr(mod, splits[0])
    for i,split in enumerate(splits[1:]):
        if not hasattr(elt, split):
            raise AssertionError(f"Class {'.'.join(splits[:i+1])} doesn't have a function named {split}.")
        elt = getattr(elt, split)
    show_doc(elt, doc_string, ft_name, arg_comments, alt_doc_string)
def get_exports(mod):
    "Names from `mod.__all__` (or `dir(mod)` when there is no `__all__`) that do not start with an underscore."
    try: candidates = mod.__all__
    except AttributeError: candidates = dir(mod)
    exported = []
    for candidate in candidates:
        if not candidate.startswith('_'): exported.append(candidate)
    return exported
def get_ft_names(mod, include_inner=False)->List[str]:
    "Return all the functions of module `mod`."
    # If the module has an attribute __all__, it picks those.
    # Otherwise, it returns all the functions defined inside a module.
    fn_names = []
    for elt_name in get_exports(mod):
        elt = getattr(mod,elt_name)
        # Objects with no source file are skipped. `Exception` instead of a bare except,
        # so that KeyboardInterrupt / SystemExit are not swallowed.
        try: fname = inspect.getfile(elt)
        except Exception: continue
        if mod.__file__.endswith('__init__.py'):
            # for a package `__init__`, only the modules it exposes are listed
            if inspect.ismodule(elt): fn_names.append(elt_name)
            else: continue
        else:
            # This removes the objects imported from elsewhere
            if (fname != mod.__file__): continue
            if inspect.isclass(elt) or inspect.isfunction(elt): fn_names.append(elt_name)
            else: continue
        if include_inner and inspect.isclass(elt) and not is_enum(elt.__class__):
            fn_names.extend(get_inner_fts(elt))
    return fn_names
def get_inner_fts(elt)->List[str]:
    "Dotted names of the public functions and methods found in `elt.__dict__`, recursing into nested classes."
    prefix = elt.__name__
    found = []
    for attr_name in elt.__dict__.keys():
        if attr_name.startswith('_'): continue
        member = getattr(elt, attr_name)
        if inspect.isfunction(member) or inspect.ismethod(member):
            found.append(f'{prefix}.{attr_name}')
        if inspect.isclass(member):
            found.extend(f'{prefix}.{inner}' for inner in get_inner_fts(member))
    return found
def get_module_toc(mod_name):
    "Display table of contents for given `mod_name`."
    mod = import_mod(mod_name)
    ft_names = mod.__all__ if hasattr(mod,'__all__') else get_ft_names(mod)
    # Sorted copy: an in-place sort here would reorder the module's own `__all__`
    ft_names = sorted(ft_names, key=str.lower)
    tabmat = ''
    for ft_name in ft_names:
        tabmat += f'- [{ft_name}](#{ft_name})\n'
        elt = getattr(mod, ft_name)
        if inspect.isclass(elt) and not is_enum(elt.__class__):
            # classes (other than enums) also list their inner functions
            in_ft_names = get_inner_fts(elt)
            for name in in_ft_names:
                tabmat += f' - [{name}](#{name})\n'
    display(Markdown(tabmat))
def show_video(url):
    "Display video in `url`."
    # NOTE(review): the HTML payload below is empty and `url` is not interpolated into it - confirm this is intended
    payload = HTML(f'')
    return display(payload)
def show_video_from_youtube(code, start=0):
    "Build the Youtube embed url for video `code` starting at `start` and pass it to `show_video`."
    embed_url = f'https://www.youtube.com/embed/{code}?start={start}&rel=0&controls=0&showinfo=0'
    return show_video(embed_url)
def get_anchor(fn)->str:
    "Anchor name for `fn`: its `__qualname__` when it has one, otherwise built with `fn_name`."
    if not hasattr(fn, '__qualname__'):
        # bound methods are prefixed with the name of the object they are bound to
        if inspect.ismethod(fn): return fn_name(fn.__self__) + '.' + fn_name(fn)
        return fn_name(fn)
    return fn.__qualname__
def fn_name(ft)->str:
    "Display name of `ft`, tried in order: `_typing_names`, `__name__`, a non-empty `_name`, `__origin__`, then `str(ft)`."
    if ft.__hash__ and ft in _typing_names: return _typing_names[ft]
    if hasattr(ft, '__name__'): return ft.__name__
    if hasattr(ft,'_name') and ft._name: return ft._name
    # for the last two fallbacks, only the part after the final dot is kept
    source = ft.__origin__ if hasattr(ft,'__origin__') else ft
    return str(source).rsplit('.', 1)[-1]
def get_fn_link(ft)->str:
    "Link of the form `{base}/{module}.html#{anchor}` for `ft`, where `base` is empty unless `use_relative_links` is off."
    target = getattr(ft, '__func__', ft)
    anchor = strip_fastai(get_anchor(target))
    module_name = strip_fastai(get_module_name(target))
    if use_relative_links: base = ''
    else: base = FASTAI_DOCS
    return f'{base}/{module_name}.html#{anchor}'
def get_module_name(ft)->str:
    "Name of the module that `inspect.getmodule` reports for `ft`."
    owner = inspect.getmodule(ft)
    return owner.__name__
def get_pytorch_link(ft)->str:
    "Returns link to pytorch docs of `ft`."
    ext = '.html'
    def page(doc_page, anchor):
        "Assemble a docs url from a page name and an anchor."
        return f'{PYTORCH_DOCS}{doc_page}{ext}#{anchor}'
    name = ft.__name__
    is_module = inspect.ismodule(ft)
    if name == 'device': return page('tensor_attributes', 'torch-device')
    if name == 'Tensor': return page('tensors', 'torch-tensor')
    if name.startswith('torchvision'):
        doc_path = get_module_name(ft).replace('.', '/')
        if is_module: name = name.replace('.', '-')
        return page(doc_path, name)
    if name.startswith('torch.nn') and is_module: # nn.functional is special case
        return page('nn', name.replace('.', '-'))
    paths = get_module_name(ft).split('.')
    if len(paths) == 1: return page(paths[0], f'{paths[0]}.{name}')
    offset = 1 if paths[1] == 'utils' else 0 # utils is a pytorch special case
    doc_path = paths[1+offset]
    if is_module: return page(doc_path, f'module-{name}')
    return page(doc_path, '.'.join(paths[:(2+offset)]+[name]))
def get_source_link(file, line, display_text="[source]", **kwargs)->str:
    "Returns github link for given file"
    link = f"{SOURCE_URL}{file}#L{line}"
    # NOTE(review): when `display_text` is given, only that text is returned and `link` is not part of it - confirm
    return link if display_text is None else f'{display_text}'
def get_function_source(ft, **kwargs)->str:
    "Returns link to `ft` in source code, or an empty string when its source lines cannot be read."
    try: _, first_line = inspect.getsourcelines(ft)
    except Exception: return ''
    # module `a.b.c` is looked up as file `a/b/c.py`
    rel_path = get_module_name(ft).replace('.', '/') + '.py'
    return get_source_link(rel_path, first_line, **kwargs)
def title_md(s:str, title_level:int, markdown=True):
    "Prefix `s` with `title_level` `#` characters (plus a space when `title_level` is non-zero); wrapped in `Markdown` if `markdown`."
    prefix = '#' * title_level
    if title_level: prefix = prefix + ' '
    text = prefix + s
    if markdown: return Markdown(text)
    return text
def jekyll_div(s,c,h,icon=None):
    "Display `h`: `s` as Markdown; `icon` falls back to `c` when not given."
    # NOTE(review): `icon` and `c` do not appear in the displayed text below - confirm
    icon = ifnone(icon,c)
    display(Markdown(f' {h}: {s}'))
def jekyll_note(s):
    "Show `s` through `jekyll_div` with class 'info' and heading 'Note'."
    return jekyll_div(s, 'info', 'Note')
def jekyll_warn(s):
    "Show `s` through `jekyll_div` with class 'danger', heading 'Warning' and icon 'exclamation'."
    return jekyll_div(s, 'danger', 'Warning', 'exclamation')
def jekyll_important(s):
    "Show `s` through `jekyll_div` with class 'warning' and heading 'Important'."
    return jekyll_div(s, 'warning', 'Important')
================================================
FILE: fastai/gen_doc/nbtest.py
================================================
"`gen_doc.nbtest` shows pytest documentation for module functions"
import inspect, os, re
from os.path import abspath, dirname, join
from collections import namedtuple
from fastai.gen_doc import nbdoc
from ..imports.core import *
from .core import ifnone
from .doctest import get_parent_func, relative_test_path, get_func_fq_name, DB_NAME
from nbconvert import HTMLExporter
from IPython.core import page
from IPython.core.display import display, Markdown, HTML
# Names exported by `from ... import *`
__all__ = ['show_test', 'doctest', 'find_related_tests', 'lookup_db', 'find_test_matches', 'find_test_files', 'fuzzy_test_match', 'get_pytest_html']
# Pair of a line number and the text of that line
TestFunctionMatch = namedtuple('TestFunctionMatch', ['line_number', 'line'])
def show_test(elt)->str:
    "Display, as Markdown, the tests that `build_tests_markdown` reports for `elt`."
    display(Markdown(build_tests_markdown(elt)))
def doctest(elt):
    "Inline notebook popup for `show_test`"
    md = build_tests_markdown(elt)
    output = nbdoc.md2html(md)
    # Fall back to inline markdown when the pager cannot show the html;
    # `Exception` rather than a bare except, so KeyboardInterrupt / SystemExit propagate
    try: page.page({'text/html': output})
    except Exception: display(Markdown(md))
def build_tests_markdown(elt):
    "Markdown listing the registered tests of `elt` followed by other tests that use it, or a 'no tests found' notice."
    fn_name = nbdoc.fn_name(elt)
    db_matches = [get_links(t) for t in lookup_db(elt)]
    md = '' + tests2md(db_matches, '')
    try:
        related = [get_links(t) for t in find_related_tests(elt)]
        # de-duplicated in order, leaving out anything already listed from the registry
        other_tests = [k for k in OrderedDict.fromkeys(related) if k not in db_matches]
        md += tests2md(other_tests, f'Some other tests where `{fn_name}` is used:')
    except OSError as e: pass
    if len(md.strip()) != 0:
        return (f'Tests found for `{fn_name}`: {md}'
                '\n\nTo run tests please refer to this [guide](/dev/test.html#quick-guide).')
    return (f'No tests found for `{fn_name}`.'
            ' To contribute a test please refer to [this guide](/dev/test.html)'
            ' and [this discussion](https://forums.fast.ai/t/improving-expanding-functional-tests/32929).')
def tests2md(tests, type_label:str):
if not tests: return ''
md = [f'\n\n{type_label}'] + [f'* `{cmd}` {link}' for link,cmd in sorted(tests, key=lambda k: k[1])]
return '\n'.join(md)
def get_pytest_html(elt, anchor_id:str)->Tuple[str,str]:
    "Return the `(link, body)` pair from `get_pytest_card` for the tests of `elt`."
    md = build_tests_markdown(elt)
    # nbconverter fails to parse markdown if it has both html and '\n'
    html = nbdoc.md2html(md).replace('\n','')
    card_id = anchor_id.replace('.', '-') + '-pytest'
    return get_pytest_card(html, card_id)
def get_pytest_card(html, anchor_id):
    "creates a collapsible bootstrap card for `show_test`"
    # NOTE(review): neither `html` nor `anchor_id` appears in the strings returned below - confirm
    return f'[test]', (f'')
def lookup_db(elt)->List[Dict]:
    "Return the entries stored for `elt` in the `DB_NAME` json file one directory above this module (empty list if none)."
    pkg_root = abspath(join(dirname(__file__), '..'))
    db_file = Path(pkg_root)/DB_NAME
    if not db_file.exists():
        raise Exception(f'Could not find {db_file}. Please make sure it exists at "{db_file}" or run `make test`')
    with open(db_file, 'r') as f: registry = json.load(f)
    return registry.get(get_func_fq_name(elt), [])
def find_related_tests(elt)->Tuple[List[Dict],List[Dict]]:
    "Collect the matches of `find_test_matches` over every file returned by `find_test_files(elt)`."
    return [match for test_file in find_test_files(elt) for match in find_test_matches(elt, test_file)]
def get_tests_dir(elt)->Path:
    "Resolved path of the `tests` directory three levels above this file; raises `OSError` if it does not exist."
    repo_root = Path(__file__).parent.parent.parent.resolve()
    test_dir = repo_root/'tests'
    if test_dir.exists(): return test_dir
    raise OSError('Could not find test directory at this location:', test_dir)
def get_file(elt)->str:
    "Source file of `elt` (unwrapped through `__wrapped__` when present), or `None` when it is not a fastai object."
    target = getattr(elt, '__wrapped__', elt)
    if nbdoc.is_fastai_class(target): return inspect.getfile(target)
    return None
def find_test_files(elt, exact_match:bool=False)->List[Path]:
    "Searches in `fastai/tests` directory for module tests"
    test_dir = get_tests_dir(elt)
    # `exact_match` is forwarded to the matcher (it used to be accepted here but never applied)
    return [test_dir/o.name for o in os.scandir(test_dir) if _is_file_match(elt, o.name, exact_match=exact_match)]
def _is_file_match(elt, file_name:str, exact_match:bool=False)->bool:
    "Match `file_name` against `test_<submodule>...<module stem>.py` for `elt`; returns the match object (truthy), `None`, or `False` when `elt` has no file."
    fp = get_file(elt)
    if fp is None: return False
    subdir = ifnone(_submodule_name(elt), '')
    exact_re = '' if exact_match else r'\w*'
    # Raw f-string keeps the regex escapes intact; the interpolated names are escaped so they match literally
    return re.match(rf'test_{re.escape(subdir)}\w*{re.escape(Path(fp).stem)}{exact_re}\.py', file_name)
def _submodule_name(elt)->str:
    "Second component of `elt.__module__` when it has more than two components (utils, text, vision, imports, etc.), else `None`."
    if inspect.ismodule(elt): return None
    parts = elt.__module__.split('.')
    return parts[1] if len(parts) > 2 else None
def find_test_matches(elt, test_file:Path)->Tuple[List[Dict],List[Dict]]:
    "Run `fuzzy_test_match` over the lines of `test_file` using the qualified name of `elt` (empty for modules)."
    lines = get_lines(test_file)
    rel_path = relative_test_path(test_file)
    if inspect.ismodule(elt): fn_name = ''
    else: fn_name = get_qualname(elt)
    return fuzzy_test_match(fn_name, lines, rel_path)
def get_qualname(elt):
    "Return `elt.__qualname__` when available, otherwise the name given by `nbdoc.fn_name`."
    # `fn_name` lives in `nbdoc`; this module imports `nbdoc` itself, not the bare name
    return elt.__qualname__ if hasattr(elt, '__qualname__') else nbdoc.fn_name(elt)
def separate_comp(qualname:str):
    "Split a dotted name into `(parent components, last component)`, the last one with a leading underscore removed."
    if not isinstance(qualname, str): qualname = get_qualname(qualname)
    *parents, leaf = qualname.split('.')
    return parents, remove_underscore(leaf)
def remove_underscore(fn_name):
    "Drop one leading underscore (private method prefix) from `fn_name`, if present."
    return fn_name[1:] if fn_name and fn_name[0] == '_' else fn_name
def fuzzy_test_match(fn_name:str, lines:List[Dict], rel_path:str)->List[TestFunctionMatch]:
    "Find any lines where `fn_name` is invoked and return the parent test function"
    fuzzy_line_matches = _fuzzy_line_match(fn_name, lines)
    fuzzy_matches = [get_parent_func(lno, lines, ignore_missing=True) for lno,_ in fuzzy_line_matches]
    # Explicit `is not None`: `None.__ne__(x)` gives `NotImplemented` for any non-None `x`,
    # and relying on the truth value of `NotImplemented` is deprecated
    fuzzy_matches = [m for m in fuzzy_matches if m is not None]
    return [map_test(rel_path, lno, l) for lno,l in fuzzy_matches]
def _fuzzy_line_match(fn_name:str, lines)->List[TestFunctionMatch]:
    "Find any lines where `fn_name` is called; returns `(index, line)` pairs"
    _,fn_name = separate_comp(fn_name)
    # Compiled once outside the loop; raw f-string keeps the regex escapes and the name is escaped to match literally
    call_re = re.compile(rf'.*[\s\.\(]{re.escape(fn_name)}[\.\(]')
    return [(idx,line) for idx,line in enumerate(lines) if call_re.match(line)]
def get_lines(file:Path)->List[str]:
    "Read `file` in text mode and return its lines, line endings included."
    with open(file, 'r') as handle:
        content = handle.readlines()
    return content
def map_test(test_file, line, line_text):
"Creates dictionary test format to match doctest api"
test_name = re.match(f'\s*def (test_\w*)', line_text).groups(0)[0]
return { 'file': test_file, 'line': line, 'test': test_name }
def get_links(metadata)->Tuple[str,str]:
    "Returns source code link and pytest command"
    source_link = nbdoc.get_source_link(**metadata)
    command = pytest_command(**metadata)
    return source_link, command
def pytest_command(file:str, test:str, **kwargs)->str:
    "Returns CLI command to run specific test function"
    target = f'{file}::{test}'
    return 'pytest -sv ' + target
================================================
FILE: fastai/general_optimizer.py
================================================
from .torch_core import *
from torch.optim import Optimizer
import types
# Names exported by `from ... import *`
__all__ = ['StatScope', 'Statistic', 'ConstStatistic', 'AvgStatistic', 'AvgSquare', 'GeneralOptimizer']
# The five scopes a statistic can be declared with
StatScope = Enum('StatScope', 'Global Group Layer Channel Weight')
@dataclass
class Statistic:
    "Base description of a statistic: its `name`, a `param`, a `scope` and an `init` value. Subclasses implement the three hooks."
    name:str
    param:float=0.9 # e.g. for exp moving average
    scope:StatScope=StatScope.Weight
    init:float=0. # starting value

    @property
    def buf(self):
        "Buffer key for this statistic: `name` followed by `_buffer`."
        return f'{self.name}_buffer'

    def new_step(self):
        "Set state when computing statistics for Global or Group"
        raise NotImplementedError

    def accumulate(self, val):
        "Add `val` to statistic"
        raise NotImplementedError

    def update(self, state, param, val=None, step=None):
        "Update state with accumulated, or `val` (if `Weight` or `Layer` scope)"
        raise NotImplementedError
class ConstStatistic(Statistic):
    "Statistic with no buffer: `update` simply hands back `param`."
    @property
    def buf(self): return None
    def new_step(self): pass
    # Accepts `val` like `Statistic.accumulate`; it stays optional so existing no-argument calls keep working
    def accumulate(self, val=None): pass
    def update(self, state, param, val=None, step=None): return param
@dataclass
class CounterStat(Statistic):
    "Statistic whose `update` returns `state + 1`."
    def __post_init__(self):
        # The name given at construction becomes the buffer key; `name` is cleared and `init` forced to 0
        buffer_key = self.name
        self.init = 0
        self._buf = buffer_key
        self.name = None

    @property
    def buf(self):
        "Buffer key captured in `__post_init__`."
        return self._buf

    def new_step(self): pass

    def accumulate(self, val): pass

    def update(self, state, param, val=None, step=None):
        "Return `state` incremented by one; the other arguments are not used."
        return state + 1
@dataclass
class AvgStatistic(Statistic):
    "Running average of a value; `decay` weights each new value by `1-param`, `debias` divides the result by `1 - param**step`."
    decay:bool=False
    debias:bool=False
    # Reset the accumulator used by the final `elif` branch of `update`
    def new_step(self): self.val,self.count = 0.,0
    # Add one reduced observation (`_get_val1`) to the accumulator
    def accumulate(self, val):
        self.count += 1
        self.val += self._get_val1(val)
    # Reduction to the mean of `val`
    def _get_val1(self, val): return val.mean()
    # In-place addition of `val` to `state`
    # NOTE(review): `add_(1-param, val)` passes the scalar before the tensor - confirm this overload against the supported torch version
    def _get_val2(self, state, val, param): return state.add_(1-param, val) if self.decay else state.add_(val)
    # Same as `_get_val2`, after flattening `val` to (size(0), -1) and taking the mean over dim 1
    def _get_val3(self, state, val, param):
        v = val.view(val.size(0), -1).mean(1)
        return state.add_(1-param, v) if self.decay else state.add_(v)
    # Return the new state for the current scope; `state` is returned unchanged when nothing was accumulated
    def update(self, state, param, val=None, step=None):
        if self.scope == StatScope.Weight:
            # `state` is a tensor
            # (scaled in place by `param` before the new value is added)
            res = self._get_val2(state.mul_(param), val, param)
        elif self.scope == StatScope.Channel:
            # `state` is a tensor of size n_channels
            res = self._get_val3(state.mul_(param), val, param)
        # For everything else, `state` is a scalar
        elif self.scope == StatScope.Layer: res = state*param + self._get_val1(val) * (1-param if self.decay else 1.)
        # remaining scopes use the mean of what `accumulate` collected since `new_step`
        elif self.count != 0: res = state*param + self.val/self.count * (1-param if self.decay else 1.)
        else: return state
        # NOTE(review): `/=` may act in place on the tensor returned by the Weight/Channel branches - confirm that is intended
        if self.debias and step is not None: res /= (1 - param ** step)
        return res
class AvgSquare(AvgStatistic):
    "Variant of `AvgStatistic` that averages squared values; `decay` defaults to `True` here."
    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):
        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)

    def _get_val1(self, val):
        "Squared norm of `val` divided by its number of elements."
        sq_norm = torch.norm(val).pow(2)
        return sq_norm/val.numel()

    def _get_val2(self, state, val, param):
        "Add `val*val` to `state` in place, scaled by `1-param` when `decay` is set."
        if self.decay: return state.addcmul_(1-param, val, val)
        return state.addcmul_(val, val)

    def _get_val3(self, state, val, param):
        "Same as `_get_val2`, applied to the mean of `val` over everything but its first dimension."
        per_row = val.view(val.size(0), -1).mean(1)
        if self.decay: return state.addcmul_(1-param, per_row, per_row)
        return state.addcmul_(per_row, per_row)
class GeneralOptimizer(Optimizer):
    "Optimizer that keeps the statistics listed in `stats` up to date and calls `on_step` on every parameter that has a gradient."
    def __init__(self, params, stats=None, on_step:Callable=None):
        # Each named statistic contributes its `param` as a per-group hyper-parameter keyed by its name.
        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}
        super().__init__(params, defaults)
        self.global_stats,self.group_stats,self.layer_stats,self.channel_stats,self.weight_stats = self._split_stats(stats)
        self.init_stats()
        # A custom `on_step` is bound to this instance so it receives `self` like a method.
        if on_step is not None: self.on_step = types.MethodType(on_step, self)

    def step(self, closure=None):
        "Update all statistics, then apply `on_step` to each parameter with a gradient. `closure` is not used."
        self.update_stats()
        for i,pg in enumerate(self.param_groups):
            for p in pg['params']:
                if p.grad is not None: self.on_step(p, pg, i)

    def on_step(self, p, group, group_idx):
        "Default update: subtract the gradient scaled by the group's `lr` from `p`."
        p.data.add_(-group['lr'], p.grad.data)

    def _split_stats(self, stats):
        "Split `stats` into one list per `StatScope`, adding a `step` counter wherever a statistic asks for `debias`."
        splits = [[stat for stat in listify(stats) if stat.scope==scope] for scope in StatScope]
        per_param = splits[2]+splits[3]+splits[4]
        # (statistics to inspect, list receiving the counter, scope given to the counter).
        # Layer, Channel and Weight statistics share one per-parameter state dict, so their
        # counter must live in a list that is actually returned (the Layer one), not in the
        # temporary concatenation, otherwise it is never initialised nor updated.
        plan = ((splits[0], splits[0], StatScope.Global),
                (splits[1], splits[1], StatScope.Group),
                (per_param, splits[2], StatScope.Layer))
        for inspected,target,scope in plan:
            if any(getattr(s, 'debias', False) for s in inspected):
                # Inserted first so the counter is updated before the statistics that read it.
                target.insert(0, CounterStat('step', scope=scope))
        return splits

    def _init_stats(self, stats, data=None):
        "Initial buffers for `stats`: the plain `init` value, or a tensor shaped like `data` filled with it."
        return {stat.buf: stat.init if data is None
                else torch.zeros_like(data) + stat.init for stat in stats if stat.buf is not None}

    def init_stats(self):
        "Create the state entries for the global, per-group and per-parameter statistics."
        self.state['global'] = self._init_stats(self.global_stats)
        for i,pg in enumerate(self.param_groups):
            self.state[f'group{i}'] = self._init_stats(self.group_stats)
            for p in pg['params']:
                self.state[p] = self._init_stats(self.layer_stats)
                self.state[p].update(self._init_stats(self.channel_stats, p.data.view(p.data.size(0), -1).mean(1)))
                self.state[p].update(self._init_stats(self.weight_stats, p.data))

    def _set_bufs(self, p, stats, pg, val=None):
        "Store in `self.state[p]` the updated value of every statistic in `stats` that has a buffer."
        d = self.state[p]
        for stat in stats:
            if stat.buf is None: continue
            # `pg.get`: the step counter has no name and therefore no hyper-parameter in `pg`;
            # its `update` ignores the value, while named statistics always find theirs.
            d[stat.buf] = stat.update(d[stat.buf], pg.get(stat.name), val=val, step=d.get('step', None))

    def update_stats(self):
        "Accumulate the gradients into Global/Group statistics and update every buffer."
        for stat in self.global_stats: stat.new_step()
        for i,pg in enumerate(self.param_groups):
            for stat in self.group_stats: stat.new_step()
            for p in pg['params']:
                if p.grad is not None:
                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)
                    self._set_bufs(p, self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)
            self._set_bufs(f'group{i}', self.group_stats, pg)
        # Global statistics read their hyper-parameters from the first parameter group.
        self._set_bufs('global', self.global_stats, self.param_groups[0])
================================================
FILE: fastai/imports/__init__.py
================================================
from .core import *
from .torch import *
================================================
FILE: fastai/imports/core.py
================================================
import csv, gc, gzip, os, pickle, shutil, sys, warnings, yaml, io, subprocess
import math, matplotlib.pyplot as plt, numpy as np, pandas as pd, random
import scipy.stats, scipy.special
import abc, collections, hashlib, itertools, json, operator, pathlib
import mimetypes, inspect, typing, functools, importlib, weakref
import html, re, requests, tarfile, numbers, tempfile, bz2
from abc import abstractmethod, abstractproperty
from collections import abc, Counter, defaultdict, namedtuple, OrderedDict
from collections.abc import Iterable
import concurrent
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from copy import copy, deepcopy
from dataclasses import dataclass, field, InitVar
from enum import Enum, IntEnum
from functools import partial, reduce
from pdb import set_trace
from matplotlib import patches, patheffects
from numpy import array, cos, exp, log, sin, tan, tanh
from operator import attrgetter, itemgetter
from pathlib import Path
from warnings import warn
from contextlib import contextmanager
from fastprogress.fastprogress import MasterBar, ProgressBar
from matplotlib.patches import Patch
from pandas import Series, DataFrame
from io import BufferedWriter, BytesIO
import pkg_resources
pkg_resources.require("fastprogress>=0.1.19")
from fastprogress.fastprogress import master_bar, progress_bar
#for type annotations
from numbers import Number
from typing import Any, AnyStr, Callable, Collection, Dict, Hashable, Iterator, List, Mapping, NewType, Optional
from typing import Sequence, Tuple, TypeVar, Union
from types import SimpleNamespace
def try_import(module):
    "Try to import `module`. Returns module's object on success, None on failure"
    try:
        return importlib.import_module(module)
    # Any import-time error means "not available"; KeyboardInterrupt/SystemExit still propagate.
    except Exception:
        return None
def have_min_pkg_version(package, version):
    "Check whether we have at least `version` of `package`. Returns True on success, False otherwise."
    try:
        pkg_resources.require(f"{package}>={version}")
    # A missing package or a version conflict both mean "no"; KeyboardInterrupt/SystemExit still propagate.
    except Exception:
        return False
    return True
================================================
FILE: fastai/imports/torch.py
================================================
import torch, torch.nn.functional as F
from torch import ByteTensor, DoubleTensor, FloatTensor, HalfTensor, LongTensor, ShortTensor, Tensor
from torch import nn, optim, as_tensor
from torch.utils.data import BatchSampler, DataLoader, Dataset, Sampler, TensorDataset
from torch.nn.utils import weight_norm, spectral_norm
================================================
FILE: fastai/launch.py
================================================
import subprocess, torch
from fastai.script import *
@call_parse
def main(
    gpus:Param("The GPUs to use for distributed training", str)='all',
    script:Param("Script to run", str, opt=False)='',
    args:Param("Args to pass to script", nargs='...', opt=False)=''
):
    "PyTorch distributed training launch helper that spawns multiple distributed processes"
    # Loosely based on torch.distributed.launch
    current_env = os.environ.copy()
    if gpus=='all':
        gpus = list(range(torch.cuda.device_count()))
    elif ',' in gpus:
        # Comma-separated ids, e.g. "0,1,10"
        gpus = [g.strip() for g in gpus.split(',') if g.strip()]
    else:
        # One single-digit id per character, e.g. "0123"
        gpus = list(gpus)
    current_env["WORLD_SIZE"] = str(len(gpus))
    current_env["MASTER_ADDR"] = '127.0.0.1'
    current_env["MASTER_PORT"] = '29500'
    # `list(args)` also covers the '' default, which cannot be concatenated to a list.
    extra_args = list(args)
    processes = []
    for i,gpu in enumerate(gpus):
        # One process per GPU, each with its own rank; the environment is captured when the process is spawned.
        current_env["RANK"] = str(i)
        cmd = [sys.executable, "-u", script, f"--gpu={gpu}"] + extra_args
        process = subprocess.Popen(cmd, env=current_env)
        processes.append(process)
    for process in processes: process.wait()
================================================
FILE: fastai/layers.py
================================================
"`fastai.layers` provides essential functions to building and modifying `model` architectures"
from .torch_core import *
# Names exported by a star-import of this module.
__all__ = ['AdaptiveConcatPool2d', 'BCEWithLogitsFlat', 'BCEFlat', 'MSELossFlat', 'CrossEntropyFlat', 'Debugger',
'Flatten', 'Lambda', 'PoolFlatten', 'View', 'ResizeBatch', 'bn_drop_lin', 'conv2d', 'conv2d_trans', 'conv_layer',
'embedding', 'simple_cnn', 'NormType', 'relu', 'batchnorm_2d', 'trunc_normal_', 'PixelShuffle_ICNR', 'icnr',
'NoopLoss', 'WassersteinLoss', 'SelfAttention', 'SequentialEx', 'MergeLayer', 'res_block', 'sigmoid_range',
'SigmoidRange', 'PartialLayer', 'FlattenedLoss', 'BatchNorm1dFlat', 'LabelSmoothingCrossEntropy', 'PooledSelfAttention2d']
class Lambda(Module):
    "Layer wrapping `func`: the forward pass returns `func(x)`."
    def __init__(self, func:LambdaFunc):
        self.func = func

    def forward(self, x):
        return self.func(x)
class View(Module):
    "Layer returning `x.view(size)`, where `size` is the tuple of positional arguments."
    def __init__(self, *size:int):
        self.size = size

    def forward(self, x):
        target_shape = self.size
        return x.view(target_shape)
class ResizeBatch(Module):
    "Layer viewing `x` as `(x.size(0), *size)`, i.e. the first dimension is kept."
    def __init__(self, *size:int):
        self.size = size

    def forward(self, x):
        batch = (x.size(0),)
        return x.view(batch + self.size)
class Flatten(Module):
    "Flatten `x` to `(x.size(0), -1)`, or to a rank-1 tensor when `full` is set."
    def __init__(self, full:bool=False):
        self.full = full

    def forward(self, x):
        if self.full: return x.view(-1)
        return x.view(x.size(0), -1)
def PoolFlatten()->nn.Sequential:
    "Return `nn.AdaptiveAvgPool2d(1)` followed by `Flatten()`."
    pool = nn.AdaptiveAvgPool2d(1)
    return nn.Sequential(pool, Flatten())
# Normalization kinds accepted by the `norm_type` arguments in this module.
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Group Instance SpectralGN')
def batchnorm_2d(nf:int, norm_type:NormType=NormType.Batch):
    "`nn.BatchNorm2d(nf)` with bias filled with 1e-3 and weight filled with 0 for `NormType.BatchZero`, 1 otherwise."
    bn = nn.BatchNorm2d(nf)
    weight_value = 0. if norm_type==NormType.BatchZero else 1.
    with torch.no_grad():
        bn.bias.fill_(1e-3)
        bn.weight.fill_(weight_value)
    return bn
def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0., actn:Optional[nn.Module]=None):
    "List of layers: batchnorm (if `bn`), dropout (if `p` is non-zero), linear (`n_in`,`n_out`), then `actn` if given."
    layers = []
    if bn: layers.append(nn.BatchNorm1d(n_in))
    if p != 0: layers.append(nn.Dropout(p))
    layers.append(nn.Linear(n_in, n_out))
    if actn is not None: layers.append(actn)
    return layers
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    "Create a `nn.Conv1d` with kaiming-normal weights (and zeroed bias if `bias`), wrapped in spectral normalization."
    layer = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return spectral_norm(layer)
class PooledSelfAttention2d(Module):
    "Pooled self attention layer for 2d."
    def __init__(self, n_channels:int):
        self.n_channels = n_channels
        # 1x1 convolutions, each wrapped in spectral normalization
        self.theta = spectral_norm(conv2d(n_channels, n_channels//8, 1)) # query
        self.phi = spectral_norm(conv2d(n_channels, n_channels//8, 1)) # key
        self.g = spectral_norm(conv2d(n_channels, n_channels//2, 1)) # value
        self.o = spectral_norm(conv2d(n_channels//2, n_channels, 1))
        # Scale of the attention branch, starts at 0
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self, x):
        # code borrowed from https://github.com/ajbrock/BigGAN-PyTorch/blob/7b65e82d058bfe035fc4e299f322a1f83993e04c/layers.py#L156
        n_ch = self.n_channels
        height, width = x.shape[2], x.shape[3]
        query = self.theta(x)
        # Key and value are max-pooled 2x2, hence the `// 4` on their flattened spatial size
        key = F.max_pool2d(self.phi(x), [2,2])
        value = F.max_pool2d(self.g(x), [2,2])
        query = query.view(-1, n_ch // 8, height * width)
        key = key.view(-1, n_ch // 8, height * width // 4)
        value = value.view(-1, n_ch // 2, height * width // 4)
        beta = F.softmax(torch.bmm(query.transpose(1, 2), key), -1)
        attended = torch.bmm(value, beta.transpose(1,2)).view(-1, n_ch // 2, height, width)
        o = self.o(attended)
        return self.gamma * o + x
class SelfAttention(Module):
    "Self attention layer for nd."
    def __init__(self, n_channels:int):
        self.query = conv1d(n_channels, n_channels//8)
        self.key = conv1d(n_channels, n_channels//8)
        self.value = conv1d(n_channels, n_channels)
        # Scale of the attention branch, starts at 0
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self, x):
        #Notation from https://arxiv.org/pdf/1805.08318.pdf
        orig_size = x.size()
        # Keep the first two dimensions and flatten the rest into one
        x = x.view(*orig_size[:2],-1)
        f = self.query(x)
        g = self.key(x)
        h = self.value(x)
        scores = torch.bmm(f.permute(0,2,1).contiguous(), g)
        beta = F.softmax(scores, dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*orig_size).contiguous()
def conv2d(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias=False, init:LayerFunc=nn.init.kaiming_normal_) -> nn.Conv2d:
    "Create a `nn.Conv2d` layer and pass it with `init` to `init_default`. `padding` defaults to `ks//2`."
    if padding is None:
        padding = ks//2
    conv = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias)
    return init_default(conv, init)
def conv2d_trans(ni:int, nf:int, ks:int=2, stride:int=2, padding:int=0, bias=False) -> nn.ConvTranspose2d:
    "Create a `nn.ConvTranspose2d` layer from `ni` to `nf` channels."
    conv_kwargs = dict(kernel_size=ks, stride=stride, padding=padding, bias=bias)
    return nn.ConvTranspose2d(ni, nf, **conv_kwargs)
def relu(inplace:bool=False, leaky:float=None):
    "Return `nn.LeakyReLU` with slope `leaky` when it is given, `nn.ReLU` otherwise; both honour `inplace`."
    if leaky is None:
        return nn.ReLU(inplace=inplace)
    return nn.LeakyReLU(inplace=inplace, negative_slope=leaky)
def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False,
               norm_type:Optional[NormType]=NormType.Batch, use_activ:bool=True, leaky:float=None,
               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):
    "Create a sequence of convolution (`ni` to `nf`), ReLU (if `use_activ`), batchnorm (for the `Batch`/`BatchZero` norm types) and `SelfAttention` (if `self_attention`)."
    if padding is None:
        padding = 0 if transpose else (ks-1)//2
    bn = norm_type in (NormType.Batch, NormType.BatchZero)
    # Without an explicit `bias`, the convolution gets one only when no batchnorm follows
    if bias is None:
        bias = not bn
    if transpose: conv_func = nn.ConvTranspose2d
    elif is_1d: conv_func = nn.Conv1d
    else: conv_func = nn.Conv2d
    conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
    if norm_type==NormType.Weight: conv = weight_norm(conv)
    elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        bn_cls = nn.BatchNorm1d if is_1d else nn.BatchNorm2d
        layers.append(bn_cls(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)
class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers):
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for layer in self.layers:
            # Expose the input of the whole sequence to `layer` through its own input
            res.orig = x
            out = layer(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig = None
            res = out
        return res

    def __getitem__(self,i):
        return self.layers[i]

    def append(self,l):
        return self.layers.append(l)

    def extend(self,l):
        return self.layers.extend(l)

    def insert(self,i,l):
        return self.layers.insert(i,l)
class MergeLayer(Module):
    "Merge `x` with the shortcut stored in `x.orig`, by adding them or concatenating them on dim 1 if `dense=True`."
    def __init__(self, dense:bool=False):
        self.dense = dense

    def forward(self, x):
        if self.dense:
            return torch.cat([x,x.orig], dim=1)
        return (x+x.orig)
def res_block(nf, dense:bool=False, norm_type:Optional[NormType]=NormType.Batch, bottle:bool=False, **conv_kwargs):
    "Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`."
    # For a non-dense block with `Batch` norm, the second conv is asked for `BatchZero` instead
    use_batch_zero = (not dense) and (norm_type==NormType.Batch)
    norm2 = NormType.BatchZero if use_batch_zero else norm_type
    # `bottle` halves the number of features between the two convs
    nf_inner = nf//2 if bottle else nf
    first = conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs)
    second = conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs)
    return SequentialEx(first, second, MergeLayer(dense))
def sigmoid_range(x:Tensor, low:int, high:int):
    "Sigmoid of `x`, rescaled to the range `(low, high)`"
    span = high - low
    return torch.sigmoid(x) * span + low
class SigmoidRange(Module):
    "Sigmoid module with range `(low, high)`"
    def __init__(self, low:int, high:int):
        self.low = low
        self.high = high

    def forward(self, x):
        return sigmoid_range(x, self.low, self.high)
class PartialLayer(Module):
    "Layer that applies `partial(func, **kwargs)`."
    def __init__(self, func, **kwargs):
        # The textual form is built once, from `func` and the keyword dict
        self.repr = f'{func}({kwargs})'
        self.func = partial(func, **kwargs)

    def forward(self, x):
        return self.func(x)

    def __repr__(self):
        return self.repr
class AdaptiveConcatPool2d(Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`."
    def __init__(self, sz:Optional[int]=None):
        "Both pools use output size `sz`, or 1 when `sz` is None (or 0)."
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x):
        # Max pool first, then average pool, joined on dim 1
        pooled = [self.mp(x), self.ap(x)]
        return torch.cat(pooled, 1)
class Debugger(Module):
    "A module to debug inside a model: drops into `set_trace()` and returns its input unchanged."
    def forward(self,x:Tensor) -> Tensor:
        # Execution pauses here on every forward pass
        set_trace()
        return x
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function. `x` is modified in place; nothing is returned."
    ni,nf,h,w = x.shape
    n_rep = scale**2
    ni2 = int(ni/n_rep)
    # Initialise a kernel with `ni2` channels in the first dimension, then repeat it `scale**2` times
    base = init(torch.zeros([ni2,nf,h,w]))
    k = base.transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, n_rep)
    k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
    x.data.copy_(k)
class PixelShuffle_ICNR(Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm` (default `norm_type`). The output is blurred only if `blur`."
    def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, norm_type=NormType.Weight, leaky:float=None):
        nf = ifnone(nf, ni)
        self.conv = conv_layer(ni, nf*(scale**2), ks=1, norm_type=norm_type, use_activ=False)
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # The `blur` flag needs its own attribute: `self.blur` holds the pooling module below,
        # which is always truthy and so cannot double as the on/off switch.
        self.do_blur = blur
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1,0,1,0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self,x):
        x = self.shuf(self.relu(self.conv(x)))
        # Instances pickled before `do_blur` existed always blurred; keep doing so for them.
        if getattr(self, 'do_blur', True):
            return self.blur(self.pad(x))
        return x
class FlattenedLoss():
    "Same as `func`, but flattens input and target."
    def __init__(self, func, *args, axis:int=-1, floatify:bool=False, is_2d:bool=True, **kwargs):
        # `func` is a loss class; it is instantiated with the remaining arguments
        self.func = func(*args,**kwargs)
        self.axis = axis
        self.floatify = floatify
        self.is_2d = is_2d
        functools.update_wrapper(self, self.func)

    def __repr__(self):
        return f"FlattenedLoss of {self.func}"

    @property
    def reduction(self):
        "`reduction` of the wrapped loss."
        return self.func.reduction

    @reduction.setter
    def reduction(self, v):
        self.func.reduction = v

    def __call__(self, input:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        # Move `axis` last on both tensors before flattening
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        if self.floatify:
            target = target.float()
        # `is_2d` keeps the last dimension of the input; otherwise it is flattened completely
        if self.is_2d: input = input.view(-1,input.shape[-1])
        else: input = input.view(-1)
        flat_target = target.view(-1)
        return self.func.__call__(input, flat_target, **kwargs)
def CrossEntropyFlat(*args, axis:int=-1, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    loss_cls = nn.CrossEntropyLoss
    return FlattenedLoss(loss_cls, *args, axis=axis, **kwargs)
def BCEWithLogitsFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    loss_cls = nn.BCEWithLogitsLoss
    # The input is flattened completely (`is_2d=False`)
    return FlattenedLoss(loss_cls, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
def BCEFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.BCELoss`, but flattens input and target."
    loss_cls = nn.BCELoss
    # The input is flattened completely (`is_2d=False`)
    return FlattenedLoss(loss_cls, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
def MSELossFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.MSELoss`, but flattens input and target."
    loss_cls = nn.MSELoss
    # The input is flattened completely (`is_2d=False`)
    return FlattenedLoss(loss_cls, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
class NoopLoss(Module):
    "Just returns the mean of the `output`."
    def forward(self, output, *args):
        # Any further positional arguments are ignored
        return output.mean()
class WassersteinLoss(Module):
    "For WGAN: mean of `real` minus mean of `fake`."
    def forward(self, real, fake):
        real_mean = real.mean()
        return real_mean - fake.mean()
def simple_cnn(actns:Collection[int], kernel_szs:Collection[int]=None,
               strides:Collection[int]=None, bn=False) -> nn.Sequential:
    "CNN with `conv_layer` defined by `actns`, `kernel_szs` and `strides`, plus batchnorm if `bn`."
    nl = len(actns)-1
    kernel_szs = ifnone(kernel_szs, [3]*nl)
    strides = ifnone(strides , [2]*nl)
    last_idx = len(strides)-1
    layers = []
    for i in range_of(strides):
        # Batchnorm is requested on every conv except the last one
        norm = NormType.Batch if bn and i<last_idx else None
        layers.append(conv_layer(actns[i], actns[i+1], kernel_szs[i], stride=strides[i], norm_type=norm))
    layers.append(PoolFlatten())
    return nn.Sequential(*layers)
def trunc_normal_(x:Tensor, mean:float=0., std:float=1.) -> Tensor:
    "Truncated normal initialization: fills `x` in place and returns it."
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    x.normal_()
    # `fmod_(2)` keeps every sample strictly within 2 units before scaling and shifting
    x.fmod_(2)
    x.mul_(std)
    return x.add_(mean)
def embedding(ni:int,nf:int) -> nn.Module:
    "Create an `nn.Embedding(ni, nf)` whose weights are initialised with `trunc_normal_` (std 0.01)."
    emb = nn.Embedding(ni, nf)
    # See https://arxiv.org/abs/1711.09160
    with torch.no_grad():
        trunc_normal_(emb.weight, std=0.01)
    return emb
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d`, but first flattens leading dimensions"
    def forward(self, x):
        if x.dim() != 2:
            # Collapse everything but the last dimension, normalise, then restore the shape
            *lead, n_feat = x.shape
            flat = x.contiguous().view(-1, n_feat)
            return super().forward(flat).view(*lead, n_feat)
        return super().forward(x)
class LabelSmoothingCrossEntropy(Module):
    "Blend of `eps/c` times the summed negative log-probabilities (`c` = size of the last dim) and `1-eps` times the NLL loss."
    def __init__(self, eps:float=0.1, reduction='mean'):
        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target):
        n_classes = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction=='sum':
            loss = -log_preds.sum()
        else:
            # Sum over the last dimension; averaged only for 'mean', left per-item otherwise
            loss = -log_preds.sum(dim=-1)
            if self.reduction=='mean': loss = loss.mean()
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return loss*self.eps/n_classes + (1-self.eps) * nll
================================================
FILE: fastai/metrics.py
================================================
"Implements various metrics to measure training accuracy"
from .torch_core import *
from .callback import *
from .layers import *
from .basic_train import LearnerCallback
# Names exported by a star-import of this module.
__all__ = ['error_rate', 'accuracy', 'accuracy_thresh', 'dice', 'exp_rmspe', 'fbeta','FBeta', 'mse', 'mean_squared_error',
'mae', 'mean_absolute_error', 'rmse', 'root_mean_squared_error', 'msle', 'mean_squared_logarithmic_error',
'explained_variance', 'r2_score', 'top_k_accuracy', 'KappaScore', 'ConfusionMatrix', 'MatthewsCorreff',
'Precision', 'Recall', 'R2Score', 'ExplainedVariance', 'ExpRMSPE', 'RMSE', 'Perplexity', 'AUROC', 'auc_roc_score',
'roc_curve', 'MultiLabelFbeta', 'foreground_acc']
def fbeta(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-9, sigmoid:bool=True)->Rank0Tensor:
    "Computes the f_beta between `y_pred` and `y_true`, averaged over the first dimension"
    beta2 = beta ** 2
    if sigmoid:
        y_pred = y_pred.sigmoid()
    # Binarise the predictions at `thresh`
    hard_pred = (y_pred>thresh).float()
    truth = y_true.float()
    true_pos = (hard_pred*truth).sum(dim=1)
    prec = true_pos/(hard_pred.sum(dim=1)+eps)
    rec = true_pos/(truth.sum(dim=1)+eps)
    res = (prec*rec)/(prec*beta2+rec+eps)*(1+beta2)
    return res.mean()
def accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Computes accuracy with `targs` when `input` is bs * n_classes."
    bs = targs.shape[0]
    # Predicted class = argmax over the last dimension
    preds = input.argmax(dim=-1).view(bs,-1)
    targs = targs.view(bs,-1)
    hits = (preds==targs).float()
    return hits.mean()
def accuracy_thresh(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    "Computes accuracy when `y_pred` and `y_true` are the same size."
    if sigmoid:
        y_pred = y_pred.sigmoid()
    # A prediction counts as positive when it is strictly above `thresh`
    matches = ((y_pred>thresh)==y_true.byte())
    return matches.float().mean()
def top_k_accuracy(input:Tensor, targs:Tensor, k:int=5)->Rank0Tensor:
    "Computes the Top-k accuracy (target is in the top k predictions)."
    # Indices of the `k` largest scores along the last dimension
    top_idx = input.topk(k=k, dim=-1)[1]
    expanded = targs.unsqueeze(dim=-1).expand_as(top_idx)
    found = (top_idx == expanded).max(dim=-1)[0]
    return found.float().mean()
def foreground_acc(input, target, void_code):
    "Computes non-background accuracy, e.g. camvid for multiclass segmentation"
    target = target.squeeze(1)
    # Positions labelled `void_code` are left out of the average
    keep = target != void_code
    preds = input.argmax(dim=1)
    return (preds[keep]==target[keep]).float().mean()
def error_rate(input:Tensor, targs:Tensor)->Rank0Tensor:
    "1 - `accuracy`"
    acc = accuracy(input, targs)
    return 1 - acc
def dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:
    "Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems."
    n = targs.shape[0]
    preds = input.argmax(dim=1).view(n,-1)
    targs = targs.view(n,-1)
    intersect = (preds * targs).sum().float()
    union = (preds+targs).sum().float()
    # When neither predictions nor targets contain a positive, the score is defined as 1
    if not (union > 0):
        return union.new([1.]).squeeze()
    if iou:
        return intersect / (union-intersect+eps)
    return 2. * intersect / union
def psnr(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Ten times the base-10 log of the inverse of `mean_squared_error(input, targs)`."
    mse_val = mean_squared_error(input, targs)
    return 10 * (1. / mse_val).log10()
def exp_rmspe(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Root mean square percentage error between the exponentials of `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    exp_pred = torch.exp(pred)
    exp_targ = torch.exp(targ)
    # Error relative to the (exponentiated) target
    pct_var = (exp_targ - exp_pred)/exp_targ
    return torch.sqrt((pct_var**2).mean())
def mean_absolute_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean absolute error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    abs_err = torch.abs(targ - pred)
    return abs_err.mean()
def mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean squared error between `pred` and `targ`."
    flat_pred,flat_targ = flatten_check(pred,targ)
    return F.mse_loss(flat_pred, flat_targ)
def root_mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Root mean squared error between `pred` and `targ`."
    flat_pred,flat_targ = flatten_check(pred,targ)
    mse_val = F.mse_loss(flat_pred, flat_targ)
    return torch.sqrt(mse_val)
def mean_squared_logarithmic_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean squared logarithmic error between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    # Both sides are compared as log(1 + value)
    log_pred = torch.log(1 + pred)
    log_targ = torch.log(1 + targ)
    return F.mse_loss(log_pred, log_targ)
def explained_variance(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Explained variance between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    # Share of the target variance left in the residuals
    residual_share = torch.var(targ - pred) / torch.var(targ)
    return 1 - residual_share
def r2_score(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "R2 score (coefficient of determination) between `pred` and `targ`."
    pred,targ = flatten_check(pred,targ)
    # Residual sum of squares over total sum of squares
    ss_res = torch.sum((targ - pred) ** 2)
    ss_tot = torch.sum((targ - targ.mean()) ** 2)
    return 1 - ss_res / ss_tot
class RegMetrics(Callback):
    "Stores predictions and targets to perform calculations on epoch end."
    def on_epoch_begin(self, **kwargs):
        # Start every epoch with empty accumulators
        self.targs = Tensor([])
        self.preds = Tensor([])

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        assert last_output.numel() == last_target.numel(), "Expected same numbers of elements in pred & targ"
        # Batches are moved to the CPU before being appended
        self.preds = torch.cat((self.preds, last_output.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu()))
class R2Score(RegMetrics):
    "Computes the R2 score (coefficient of determination)."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = r2_score(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class ExplainedVariance(RegMetrics):
    "Computes the explained variance."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = explained_variance(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class RMSE(RegMetrics):
    "Computes the root mean squared error."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = root_mean_squared_error(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class ExpRMSPE(RegMetrics):
    "Computes `exp_rmspe` over the predictions and targets stored during the epoch."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = exp_rmspe(self.preds, self.targs)
        return add_metrics(last_metrics, score)
# Short aliases for the regression metric functions
mse = mean_squared_error
mae = mean_absolute_error
msle = mean_squared_logarithmic_error
rmse = root_mean_squared_error
class ConfusionMatrix(Callback):
    "Computes the confusion matrix."
    def on_train_begin(self, **kwargs):
        # 0 means "not known yet"; set from the first batch's output
        self.n_classes = 0

    def on_epoch_begin(self, **kwargs):
        self.cm = None

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        preds = last_output.argmax(-1).view(-1).cpu()
        targs = last_target.cpu()
        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)
        # Broadcast both comparisons against the class indices and count the joint matches
        pred_is_class = (preds==self.x[:, None])
        targ_is_class = (targs==self.x[:, None, None])
        batch_cm = (pred_is_class & targ_is_class).sum(dim=2, dtype=torch.float32)
        if self.cm is None:
            self.cm = batch_cm
        else:
            self.cm += batch_cm

    def on_epoch_end(self, **kwargs):
        self.metric = self.cm
@dataclass
class CMScores(ConfusionMatrix):
    "Base class for metrics which rely on the calculation of the precision and/or recall score."
    average:Optional[str]="binary" # `binary`, `micro`, `macro`, `weighted` or None
    pos_label:int=1 # 0 or 1
    eps:float=1e-9

    def _recall(self):
        "Diagonal of the confusion matrix over its sums along dim 1, per class or averaged with `_weights`."
        rec = torch.diag(self.cm) / self.cm.sum(dim=1)
        if self.average is None:
            return rec
        # "micro" recall uses the "weighted" weights
        avg = "weighted" if self.average == "micro" else self.average
        weights = self._weights(avg=avg)
        return (rec * weights).sum()

    def _precision(self):
        "Diagonal of the confusion matrix over its sums along dim 0, per class or averaged with `_weights`."
        prec = torch.diag(self.cm) / self.cm.sum(dim=0)
        if self.average is None:
            return prec
        weights = self._weights(avg=self.average)
        return (prec * weights).sum()

    def _weights(self, avg:str):
        "Per-class weights for the averaging mode `avg`."
        if self.n_classes != 2 and avg == "binary":
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non binary case. Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            # All the weight goes to the positive class
            if self.pos_label == 1:
                return Tensor([0,1])
            return Tensor([1,0])
        if avg == "micro":
            return self.cm.sum(dim=0) / self.cm.sum()
        if avg == "macro":
            return torch.ones((self.n_classes,)) / self.n_classes
        if avg == "weighted":
            return self.cm.sum(dim=1) / self.cm.sum()
class Recall(CMScores):
    "Computes the Recall."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = self._recall()
        return add_metrics(last_metrics, score)
class Precision(CMScores):
    "Computes the Precision."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = self._precision()
        return add_metrics(last_metrics, score)
@dataclass
class FBeta(CMScores):
    "Computes the F`beta` score."
    beta:float=2

    def on_train_begin(self, **kwargs):
        self.n_classes = 0
        self.beta2 = self.beta ** 2
        # Remember the requested averaging; unless it is "micro", precision and recall
        # are computed per class and the averaging is applied to the F-score itself.
        self.avg = self.average
        if self.average != "micro":
            self.average = None

    def on_epoch_end(self, last_metrics, **kwargs):
        prec = self._precision()
        rec = self._recall()
        numer = (1 + self.beta2) * prec * rec
        metric = numer / (prec * self.beta2 + rec + self.eps)
        metric[metric != metric] = 0 # removing potential "nan"s
        if self.avg:
            metric = (self._weights(avg=self.avg) * metric).sum()
        return add_metrics(last_metrics, metric)

    def on_train_end(self, **kwargs):
        # Restore the averaging mode that `on_train_begin` put aside
        self.average = self.avg
@dataclass
class KappaScore(ConfusionMatrix):
    "Computes the rate of agreement (Cohens Kappa)."
    weights:Optional[str]=None # None, `linear`, or `quadratic`

    def on_epoch_end(self, last_metrics, **kwargs):
        sum0 = self.cm.sum(dim=0)
        sum1 = self.cm.sum(dim=1)
        expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum()
        n = self.n_classes
        if self.weights is None:
            # Every disagreement weighs 1, the diagonal weighs 0
            w = torch.ones((n, n))
            w[self.x, self.x] = 0
        elif self.weights == "linear" or self.weights == "quadratic":
            # Weight grows with the distance between the two class indices
            w = torch.zeros((n, n))
            w += torch.arange(n, dtype=torch.float)
            diff = w - torch.t(w)
            w = torch.abs(diff) if self.weights == "linear" else diff ** 2
        else:
            raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = torch.sum(w * self.cm) / torch.sum(w * expected)
        return add_metrics(last_metrics, 1-k)
@dataclass
class MatthewsCorreff(ConfusionMatrix):
    "Computes the Matthews correlation coefficient."
    def on_epoch_end(self, last_metrics, **kwargs):
        cm = self.cm
        t_sum = cm.sum(dim=1)
        p_sum = cm.sum(dim=0)
        n_correct = torch.trace(cm)
        n_samples = p_sum.sum()
        cov_ytyp = n_correct * n_samples - torch.dot(t_sum, p_sum)
        cov_ypyp = n_samples ** 2 - torch.dot(p_sum, p_sum)
        cov_ytyt = n_samples ** 2 - torch.dot(t_sum, t_sum)
        mcc = cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)
        return add_metrics(last_metrics, mcc)
class Perplexity(Callback):
    "Perplexity metric for language models."
    def on_epoch_begin(self, **kwargs):
        self.loss = 0.
        self.len = 0

    def on_batch_end(self, last_output, last_target, **kwargs):
        # Each batch's loss is weighted by the size of the target's dim 1
        n = last_target.size(1)
        self.loss += n * CrossEntropyFlat()(last_output, last_target)
        self.len += n

    def on_epoch_end(self, last_metrics, **kwargs):
        mean_loss = self.loss / self.len
        return add_metrics(last_metrics, torch.exp(mean_loss))
def auc_roc_score(input:Tensor, targ:Tensor):
    "Computes the area under the receiver operator characteristic (ROC) curve using the trapezoid method. Restricted binary classification tasks."
    fpr, tpr = roc_curve(input, targ)
    # Trapezoid rule: width of each FPR step times the mean of the TPR at its two ends
    widths = fpr[1:] - fpr[:-1]
    heights = (tpr[1:] + tpr[:-1]) / 2.
    return (widths * heights).sum(-1)
def roc_curve(input:Tensor, targ:Tensor):
    "Computes the receiver operator characteristic (ROC) curve by determining the true positive ratio (TPR) and false positive ratio (FPR) for various classification thresholds. Restricted to binary classification tasks."
    # Anything equal to 1 is a positive, everything else a negative.
    targ = (targ == 1)
    # Sort both tensors by decreasing score.
    desc_score_indices = torch.flip(input.argsort(-1), [-1])
    input = input[desc_score_indices]
    targ = targ[desc_score_indices]
    # A threshold is only useful where the score changes; the last sample always closes the curve.
    d = input[1:] - input[:-1]
    distinct_value_indices = torch.nonzero(d).transpose(0,1)[0]
    threshold_idxs = torch.cat((distinct_value_indices, LongTensor([len(targ) - 1]).to(targ.device)))
    tps = torch.cumsum(targ * 1, dim=-1)[threshold_idxs]
    fps = (1 + threshold_idxs - tps)
    if tps[0] != 0 or fps[0] != 0:
        # Make the curve start at (0, 0). The zero is created from `fps` so it shares its dtype and
        # device; a bare `LongTensor([0])` lives on the CPU and breaks `torch.cat` for CUDA inputs.
        origin = fps.new_zeros(1)
        fps = torch.cat((origin, fps))
        tps = torch.cat((origin, tps))
    fpr, tpr = fps.float() / fps[-1], tps.float() / tps[-1]
    return fpr, tpr
@dataclass
class AUROC(Callback):
    "Computes the area under the curve (AUC) score based on the receiver operator characteristic (ROC) curve. Restricted to binary classification tasks."
    def on_epoch_begin(self, **kwargs):
        "Start the epoch with empty target and prediction buffers."
        self.targs = LongTensor([])
        self.preds = Tensor([])
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        "Store the softmax probability of the last class together with the targets, both moved to the CPU."
        last_class_prob = F.softmax(last_output, dim=1)[:,-1]
        self.preds = torch.cat((self.preds, last_class_prob.cpu()))
        self.targs = torch.cat((self.targs, last_target.cpu().long()))
    def on_epoch_end(self, last_metrics, **kwargs):
        "Score everything gathered during the epoch and append the result to `last_metrics`."
        score = auc_roc_score(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class MultiLabelFbeta(LearnerCallback):
    "Computes the fbeta score for multilabel classification"
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
    _order = -20
    def __init__(self, learn, beta=2, eps=1e-15, thresh=0.3, sigmoid=True, average="micro"):
        super().__init__(learn)
        self.eps = eps
        self.thresh = thresh
        self.sigmoid = sigmoid
        self.average = average
        self.beta2 = beta**2

    def on_train_begin(self, **kwargs):
        "Register one metric name, or one per class when `average` is \"none\"."
        self.c = self.learn.data.c
        if self.average == "none": names = [f"fbeta_{c}" for c in self.learn.data.classes]
        else: names = [f'{self.average}_fbeta']
        self.learn.recorder.add_metric_names(names)

    def on_epoch_begin(self, **kwargs):
        "Zero the per-class counters on the data's device."
        dvc = self.learn.data.device
        self.tp = torch.zeros(self.c).to(dvc)
        self.total_pred = torch.zeros(self.c).to(dvc)
        self.total_targ = torch.zeros(self.c).to(dvc)

    def on_batch_end(self, last_output, last_target, **kwargs):
        "Threshold the (optionally sigmoid-ed) outputs and update the per-class counters."
        activations = last_output.sigmoid() if self.sigmoid else last_output
        pred = activations > self.thresh
        targ = last_target.byte()
        hits = pred*targ
        self.tp += hits.sum(0).float()
        self.total_pred += pred.sum(0).float()
        self.total_targ += targ.sum(0).float()

    def fbeta_score(self, precision, recall):
        "Combine `precision` and `recall` into an F-beta score, with `eps` added to the denominator."
        numerator = (1 + self.beta2)*(precision*recall)
        return numerator/((self.beta2*precision + recall) + self.eps)

    def on_epoch_end(self, last_metrics, **kwargs):
        "Reduce the counters according to `self.average` and append the result to `last_metrics`."
        # `eps` keeps the divisions below finite when a class was never predicted or never present.
        self.total_pred += self.eps
        self.total_targ += self.eps
        mode = self.average
        if mode == "micro":
            tp_all = self.tp.sum()
            res = self.fbeta_score(tp_all / self.total_pred.sum(), tp_all / self.total_targ.sum())
        elif mode in ("macro", "weighted", "none"):
            per_class = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ))
            if mode == "macro": res = per_class.mean()
            elif mode == "weighted": res = (per_class*self.total_targ).sum() / self.total_targ.sum()
            else: res = listify(per_class)
        else:
            raise Exception("Choose one of the average types: [micro, macro, weighted, none]")
        return add_metrics(last_metrics, res)
================================================
FILE: fastai/script.py
================================================
import os, sys, subprocess, inspect
from dataclasses import dataclass
from typing import Any
from argparse import ArgumentParser
@dataclass
class Param():
    "A parameter in a function used in `anno_parser` or `call_parse`"
    # Every field except `opt` is handed to `ArgumentParser.add_argument` when it is not None.
    help:str=None
    type:type=None
    opt:bool=True
    action:str=None
    nargs:str=None
    const:str=None
    choices:str=None
    required:bool=None

    @property
    def pre(self):
        "Prefix for the argument name: '--' for an option, '' otherwise."
        if self.opt: return '--'
        return ''

    @property
    def kwargs(self):
        "The instance attributes that are set (not None), without `opt`."
        selected = {}
        for key,value in vars(self).items():
            if key == 'opt' or value is None: continue
            selected[key] = value
        return selected
def anno_parser(func):
    "Look at params (annotated with `Param`) in func and return an `ArgumentParser`"
    p = ArgumentParser(description=func.__doc__)
    for k,v in inspect.signature(func).parameters.items():
        # Only build a default `Param` for parameters that carry no annotation.
        param = func.__annotations__.get(k)
        if param is None: param = Param()
        kwargs = param.kwargs
        # `Parameter.empty` is a sentinel: test identity so the default's own `__ne__`
        # (e.g. an array-like default) is never invoked.
        if v.default is not inspect.Parameter.empty: kwargs['default'] = v.default
        p.add_argument(f"{param.pre}{k}", **kwargs)
    return p
def call_parse(func):
    "Decorator to create a simple CLI from `func` using `anno_parser`"
    # The module applying the decorator sits one frame up; keep this lookup inline so the depth is right.
    caller = inspect.currentframe().f_back.f_globals['__name__']
    if caller != "__main__": return func
    # Run as a script: parse the command line and call `func` right away (nothing is returned).
    parsed = anno_parser(func).parse_args()
    func(**vars(parsed))
def call_plac(f):
    "Decorator to create a simple CLI from `func` using `plac`"
    # The module applying the decorator sits one frame up; keep this lookup inline so the depth is right.
    caller = inspect.currentframe().f_back.f_globals['__name__']
    if caller != '__main__': return f
    # Imported lazily so `plac` is only needed when actually running as a script.
    import plac
    outcome = plac.call(f)
    if callable(outcome): outcome()
================================================
FILE: fastai/sixel.py
================================================
from .core import *
# Optional dependency loaded through `try_import`; `plot_sixel` treats a falsy value as "libsixel is unavailable".
libsixel = try_import('libsixel')
def _sixel_encode(data, width, height):
    "Encode the RGBA pixel buffer `data` of size `width` x `height` with libsixel and return the result as an ASCII string."
    sink = io.BytesIO()
    # libsixel hands every encoded chunk to this callback together with the `sink` passed alongside it.
    output = libsixel.sixel_output_new(lambda chunk, buf: buf.write(chunk), sink)
    dither = libsixel.sixel_dither_new(256)
    w = int(width)
    h = int(height)
    libsixel.sixel_dither_initialize(dither, data, w, h, libsixel.SIXEL_PIXELFORMAT_RGBA8888)
    libsixel.sixel_encode(data, w, h, 1, dither, output)
    encoded = sink.getvalue()
    return encoded.decode('ascii')
def plot_sixel(fig=None):
    "Render `fig` (the current matplotlib figure by default) and print it as sixel data; only warn if libsixel is missing."
    if not libsixel:
        warn("You could see this plot with `libsixel`. See https://github.com/saitoha/libsixel")
        return
    figure = plt.gcf() if fig is None else fig
    # The canvas must be drawn before its RGBA buffer is read.
    figure.canvas.draw()
    dpi = figure.get_dpi()
    pixels = figure.canvas.buffer_rgba()
    # Figure size times dpi gives the pixel dimensions.
    px_width = figure.get_figwidth()* dpi
    px_height = figure.get_figheight() * dpi
    print(_sixel_encode(pixels, px_width, px_height))
================================================
FILE: fastai/tabular/__init__.py
================================================
from .. import basics
from ..basics import *
from .data import *
from .transform import *
from .models import *
from .. import tabular
__all__ = [*basics.__all__, *data.__all__, *transform.__all__, *models.__all__, 'tabular']
================================================
FILE: fastai/tabular/data.py
================================================
"Data loading pipeline for structured data support. Loads from pandas DataFrame"
from ..torch_core import *
from .transform import *
from ..basic_data import *
from ..data_block import *
from ..basic_train import *
from .models import *
from pandas.api.types import is_numeric_dtype, is_categorical_dtype
__all__ = ['TabularDataBunch', 'TabularLine', 'TabularList', 'TabularProcessor', 'tabular_learner']
# Type alias: an optional collection of `TabularProc`, used to annotate `procs` in `TabularDataBunch.from_df`.
OptTabTfms = Optional[Collection[TabularProc]]
#def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)
def emb_sz_rule(n_cat:int)->int:
    "Rule of thumb for the embedding width of a variable with `n_cat` categories, capped at 600."
    suggested = round(1.6 * n_cat**0.56)
    return min(600, suggested)
def def_emb_sz(classes, n, sz_dict=None):
    "Return `(n_cat, size)` for variable `n`: the size comes from `sz_dict` when given there, else from `emb_sz_rule`."
    overrides = ifnone(sz_dict, {})
    n_cat = len(classes[n])
    rule_of_thumb = int(emb_sz_rule(n_cat))
    return n_cat, overrides.get(n, rule_of_thumb)
class TabularLine(ItemBase):
    "Basic item for tabular data."
    def __init__(self, cats, conts, classes, names):
        self.cats = cats
        self.conts = conts
        self.classes = classes
        self.names = names
        # `data` holds the two tensors built from `cats` and `conts`, in that order.
        self.data = [tensor(cats), tensor(conts)]

    def __str__(self):
        "Render every variable as 'name value; ', categorical ones first."
        n_cat = len(self.cats)
        # The first `n_cat` names go with the categorical codes, the rest with the continuous values.
        parts = [f"{n} {(self.classes[n][c])}; " for c,n in zip(self.cats, self.names[:n_cat])]
        parts += [f'{n} {c:.4f}; ' for c,n in zip(self.conts, self.names[n_cat:])]
        return ''.join(parts)
class TabularProcessor(PreProcessor):
    "Regroup the `procs` in one `PreProcessor`."
    def __init__(self, ds:ItemBase=None, procs=None):
        # Explicit `procs` win; otherwise fall back to the ones stored on `ds`, if a `ds` was given.
        procs = ifnone(procs, ds.procs if ds is not None else None)
        self.procs = listify(procs)
    def process_one(self, item):
        "Apply the stored `procs` in test mode to a single `item` and wrap the result in a `TabularLine`."
        # NOTE(review): the row is duplicated, presumably so the procs see a frame with more than one row - confirm.
        df = pd.DataFrame([item,item])
        for proc in self.procs: proc(df, test=True)
        if len(self.cat_names) != 0:
            # +1 shifts the category codes so that index 0 lines up with the '#na#' entry `process` puts first in `classes`.
            codes = np.stack([c.cat.codes.values for n,c in df[self.cat_names].items()], 1).astype(np.int64) + 1
        else: codes = [[]]
        if len(self.cont_names) != 0:
            conts = np.stack([c.astype('float32').values for n,c in df[self.cont_names].items()], 1)
        else: conts = [[]]
        # NOTE(review): the item is built with `classes=None`, so `TabularLine.__str__` cannot look up category names on it.
        classes = None
        col_names = list(df[self.cat_names].columns.values) + list(df[self.cont_names].columns.values)
        # Only the first of the two identical rows is kept.
        return TabularLine(codes[0], conts[0], classes, col_names)
    def process(self, ds):
        "Apply the `procs` to `ds.inner_df` and store codes, continuous values, classes and column names on `ds`."
        if ds.inner_df is None:
            # No dataframe to process: copy the state this processor already holds onto `ds`.
            ds.classes,ds.cat_names,ds.cont_names = self.classes,self.cat_names,self.cont_names
            ds.col_names = self.cat_names + self.cont_names
            ds.preprocessed = True
            return
        for i,proc in enumerate(self.procs):
            # An already-built `TabularProc` instance is applied in test mode.
            if isinstance(proc, TabularProc): proc(ds.inner_df, test=True)
            else:
                #cat and cont names may have been changed by transform (like Fill_NA)
                # Otherwise `proc` is built from the current names, applied with its default mode,
                # and the built object replaces the original entry in `self.procs`.
                proc = proc(ds.cat_names, ds.cont_names)
                proc(ds.inner_df)
                ds.cat_names,ds.cont_names = proc.cat_names,proc.cont_names
                self.procs[i] = proc
        self.cat_names,self.cont_names = ds.cat_names,ds.cont_names
        if len(ds.cat_names) != 0:
            # One column of int64 codes per categorical variable, shifted by one (see '#na#' below).
            ds.codes = np.stack([c.cat.codes.values for n,c in ds.inner_df[ds.cat_names].items()], 1).astype(np.int64) + 1
            # Class labels per variable, with '#na#' prepended at index 0.
            self.classes = ds.classes = OrderedDict({n:np.concatenate([['#na#'],c.cat.categories.values])
                                                     for n,c in ds.inner_df[ds.cat_names].items()})
            cat_cols = list(ds.inner_df[ds.cat_names].columns.values)
        else: ds.codes,ds.classes,self.classes,cat_cols = None,None,None,[]
        if len(ds.cont_names) != 0:
            ds.conts = np.stack([c.astype('float32').values for n,c in ds.inner_df[ds.cont_names].items()], 1)
            cont_cols = list(ds.inner_df[ds.cont_names].columns.values)
        else: ds.conts,cont_cols = None,[]
        # Categorical columns always come first in `col_names`.
        ds.col_names = cat_cols + cont_cols
        ds.preprocessed = True
class TabularDataBunch(DataBunch):
    "Create a `DataBunch` suitable for tabular data."
    @classmethod
    def from_df(cls, path, df:DataFrame, dep_var:str, valid_idx:Collection[int], procs:OptTabTfms=None,
                cat_names:OptStrList=None, cont_names:OptStrList=None, classes:Collection=None,
                test_df=None, bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
                device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False)->DataBunch:
        "Create a `DataBunch` from `df` and `valid_idx` with `dep_var`; the loader options are forwarded to `databunch`."
        # Copy so that procs appending to `cat_names` do not touch the caller's list.
        cat_names = ifnone(cat_names, []).copy()
        # By default every column that is neither categorical nor the target is continuous.
        cont_names = ifnone(cont_names, list(set(df)-set(cat_names)-{dep_var}))
        procs = listify(procs)
        src = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)
                          .split_by_idx(valid_idx))
        src = src.label_from_df(cols=dep_var) if classes is None else src.label_from_df(cols=dep_var, classes=classes)
        # The test set reuses the processor attached to the training inputs.
        if test_df is not None: src.add_test(TabularList.from_df(test_df, cat_names=cat_names, cont_names=cont_names,
                                                                 processor = src.train.x.processor))
        # `dl_tfms` used to be accepted and silently dropped; forward it like the other loader options.
        return src.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms, device=device,
                             collate_fn=collate_fn, no_check=no_check)
class TabularList(ItemList):
    "Basic `ItemList` for tabular data."
    _item_cls=TabularLine
    _processor=TabularProcessor
    _bunch=TabularDataBunch
    def __init__(self, items:Iterator, cat_names:OptStrList=None, cont_names:OptStrList=None,
                 procs=None, **kwargs)->'TabularList':
        # The dataframe itself lives in `inner_df`; `items` is only a range of row positions.
        super().__init__(range_of(items), **kwargs)
        self.cat_names = [] if cat_names is None else cat_names
        self.cont_names = [] if cont_names is None else cont_names
        self.procs = procs
        self.copy_new += ['cat_names', 'cont_names', 'procs']
        self.preprocessed = False

    @classmethod
    def from_df(cls, df:DataFrame, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, **kwargs)->'ItemList':
        "Build the list from a copy of `df`, which is stored as `inner_df`."
        return cls(items=range(len(df)), cat_names=cat_names, cont_names=cont_names, procs=procs,
                   inner_df=df.copy(), **kwargs)

    def get(self, o):
        "Return the raw row `o` before processing, or a `TabularLine` built from the processed arrays after."
        if not self.preprocessed:
            if hasattr(self, 'inner_df'): return self.inner_df.iloc[o]
            return self.items[o]
        codes = self.codes[o] if self.codes is not None else []
        conts = self.conts[o] if self.conts is not None else []
        return self._item_cls(codes, conts, self.classes, self.col_names)

    def get_emb_szs(self, sz_dict=None):
        "Return the default embedding sizes suitable for this data or takes the ones in `sz_dict`."
        return [def_emb_sz(self.classes, name, sz_dict) for name in self.cat_names]

    def reconstruct(self, t:Tensor):
        "Rebuild an item from `t`, whose first element holds the categorical part and second the continuous part."
        cats, conts = t[0], t[1]
        return self._item_cls(cats, conts, self.classes, self.col_names)

    def show_xys(self, xs, ys)->None:
        "Show the `xs` (inputs) and `ys` (targets)."
        from IPython.display import display, HTML
        names = xs[0].names + ['target']
        rows = []
        for x,y in zip(xs,ys):
            # A tensor whose `size()` is empty is treated as having no entries.
            cats = x.cats if len(x.cats.size()) > 0 else []
            conts = x.conts if len(x.conts.size()) > 0 else []
            row = [x.classes[n][c] for c,n in zip(cats, x.names[:len(cats)])]
            row += [f'{c:.4f}' for c in conts] + [y]
            rows.append(row)
        rows = np.array(rows)
        df = pd.DataFrame({n:rows[:,i] for i,n in enumerate(names)}, columns=names)
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))

    def show_xyzs(self, xs, ys, zs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions)."
        from IPython.display import display, HTML
        names = xs[0].names + ['target', 'prediction']
        rows = []
        for x,y,z in zip(xs,ys,zs):
            # A tensor whose `size()` is empty is treated as having no entries.
            cats = x.cats if len(x.cats.size()) > 0 else []
            conts = x.conts if len(x.conts.size()) > 0 else []
            row = [str(x.classes[n][c]) for c,n in zip(cats, x.names[:len(cats)])]
            row += [f'{c:.4f}' for c in conts] + [y, z]
            rows.append(row)
        rows = np.array(rows)
        df = pd.DataFrame({n:rows[:,i] for i,n in enumerate(names)}, columns=names)
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
def tabular_learner(data:DataBunch, layers:Collection[int], emb_szs:Dict[str,int]=None, metrics=None,
                    ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, **learn_kwargs):
    "Get a `Learner` using `data`, with `metrics`, including a `TabularModel` created using the remaining params."
    # Sizes given in `emb_szs` override the defaults computed by the data.
    size_overrides = ifnone(emb_szs, {})
    emb_szs = data.get_emb_szs(size_overrides)
    n_cont = len(data.cont_names)
    model = TabularModel(emb_szs, n_cont, out_sz=data.c, layers=layers, ps=ps,
                         emb_drop=emb_drop, y_range=y_range, use_bn=use_bn)
    return Learner(data, model, metrics=metrics, **learn_kwargs)
================================================
FILE: fastai/tabular/models.py
================================================
from ..torch_core import *
from ..layers import *
from ..basic_data import *
from ..basic_train import *
from ..train import ClassificationInterpretation
__all__ = ['TabularModel']
class TabularModel(Module):
    "Basic model for tabular data."
    def __init__(self, emb_szs:ListSizes, n_cont:int, out_sz:int, layers:Collection[int], ps:Collection[float]=None,
                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):
        super().__init__()
        # Default to no dropout, then let `listify` match `ps` against `layers`.
        ps = listify(ifnone(ps, [0]*len(layers)), layers)
        # Sub-modules are created in the same order as before: embeddings, dropout, batchnorm, then the stack.
        self.embeds = nn.ModuleList([embedding(n_cls, width) for n_cls,width in emb_szs])
        self.emb_drop = nn.Dropout(emb_drop)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        self.n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_cont = n_cont
        self.y_range = y_range
        sizes = self.get_sizes(layers, out_sz)
        # ReLU after every block except the last one, which gets no activation.
        activations = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        # The first block gets no dropout (and, below, no batchnorm).
        dropouts = [0.]+ps
        blocks = []
        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],dropouts,activations)):
            blocks.extend(bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act))
        if bn_final: blocks.append(nn.BatchNorm1d(sizes[-1]))
        self.layers = nn.Sequential(*blocks)

    def get_sizes(self, layers, out_sz):
        "Layer widths: the concatenated input width, the hidden `layers`, then `out_sz`."
        input_width = self.n_emb + self.n_cont
        return [input_width] + layers + [out_sz]

    def forward(self, x_cat:Tensor, x_cont:Tensor) -> Tensor:
        has_emb = self.n_emb != 0
        if has_emb:
            # Column i of `x_cat` feeds embedding i.
            embedded = [emb(x_cat[:,i]) for i,emb in enumerate(self.embeds)]
            x = self.emb_drop(torch.cat(embedded, 1))
        if self.n_cont != 0:
            x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if has_emb else x_cont
        x = self.layers(x)
        if self.y_range is not None:
            # Squash the output into (y_range[0], y_range[1]).
            span = self.y_range[1]-self.y_range[0]
            x = span * torch.sigmoid(x) + self.y_range[0]
        return x
@classmethod
def _cl_int_from_learner(cls, learn:Learner, ds_type=DatasetType.Valid, activ:nn.Module=None):
    "Create an instance of `ClassificationInterpretation` from the predictions of `learn` on `ds_type`."
    # `with_loss=True` is requested so the per-item losses are passed to the constructor too.
    return cls(learn, *learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True), ds_type=ds_type)
def _cl_int_plot_top_losses(self, k, largest:bool=True, return_table:bool=False)->Optional[plt.Figure]:
    "Generates a dataframe of 'top_losses' along with their prediction, actual, loss, and probability of the actual class."
    # NOTE(review): despite the return annotation, what may be returned below is the dataframe, not a figure.
    tl_val, tl_idx = self.top_losses(k, largest)
    classes = self.data.classes
    cat_names = self.data.x.cat_names
    cont_names = self.data.x.cont_names
    # NOTE(review): the column names are wrapped in an extra list - confirm this nesting is intended.
    df = pd.DataFrame(columns=[['Prediction', 'Actual', 'Loss', 'Probability'] + cat_names + cont_names])
    for i, idx in enumerate(tl_idx):
        da, cl = self.data.dl(self.ds_type).dataset[idx]
        cl = int(cl)
        # The item's string form is 'name value; name value; ...'; split it into one chunk per variable.
        t1 = str(da)
        t1 = t1.split(';')
        arr = []
        arr.extend([classes[self.pred_class[idx]], classes[cl], f'{self.losses[idx]:.2f}',
                    f'{self.preds[idx][cl]:.2f}'])
        # The last chunk is what follows the final ';', so it is skipped.
        for x in range(len(t1)-1):
            # Keep only the text after the last space of each chunk, i.e. the value.
            _, value = t1[x].rsplit(' ', 1)
            arr.append(value)
        df.loc[i] = arr
    display(df)
    return_fig = return_table
    # The table is always displayed; it is also returned when `return_table` (or, if that is None, `defaults.return_fig`) is truthy.
    if ifnone(return_fig, defaults.return_fig): return df
# Monkey-patch `ClassificationInterpretation` with the tabular-specific constructor and top-losses table defined above.
ClassificationInterpretation.from_learner = _cl_int_from_learner
ClassificationInterpretation.plot_top_losses = _cl_int_plot_top_losses
def _learner_interpret(learn:Learner, ds_type:DatasetType = DatasetType.Valid):
    "Build a `ClassificationInterpretation` for `learn` on the dataset selected by `ds_type`."
    interp = ClassificationInterpretation.from_learner(learn, ds_type=ds_type)
    return interp
# Expose the helper as `Learner.interpret`.
Learner.interpret = _learner_interpret
================================================
FILE: fastai/tabular/transform.py
================================================
"Cleaning and feature engineering functions for structured data"
from ..torch_core import *
from pandas.api.types import is_numeric_dtype
from datetime import date, datetime
import calendar
__all__ = ['add_datepart', 'cont_cat_split', 'Categorify', 'FillMissing', 'FillStrategy', 'Normalize', 'TabularProc',
'add_elapsed_times', 'make_date', 'add_cyclic_datepart']
def make_date(df:DataFrame, date_field:str):
    "Make sure `df[date_field]` is of the right date type, converting it in place when it is not."
    # `is_datetime64_any_dtype` accepts naive and tz-aware datetimes alike, and unlike
    # `np.issubdtype` it does not raise on pandas extension dtypes (string, category, nullable).
    if not pd.api.types.is_datetime64_any_dtype(df[date_field]):
        df[date_field] = pd.to_datetime(df[date_field], infer_datetime_format=True)
def cyclic_dt_feat_names(time:bool=True, add_linear:bool=False)->List[str]:
    "Return feature names of date/time cycles as produced by `cyclic_dt_features`."
    cycles = 'weekday day_month month_year day_year'.split()
    if time: cycles = cycles + 'hour clock min sec'.split()
    names = []
    # Each cycle contributes a cosine and a sine feature, in that order.
    for cycle in cycles:
        for fn in ('cos', 'sin'):
            names.append(f'{cycle}_{fn}')
    if add_linear: names.append('year_lin')
    return names
def cyclic_dt_features(d:Union[date,datetime], time:bool=True, add_linear:bool=False)->List[float]:
    "Calculate the cos and sin of date/time cycles."
    tt = d.timetuple()
    fs = [np.cos, np.sin]
    days_month = calendar.monthrange(d.year, d.month)[1]
    days_year = 366 if calendar.isleap(d.year) else 365
    # Position of `d` within the week, the month, the year (by month) and the year (by day).
    date_fracs = d.weekday()/7, (d.day-1)/days_month, (d.month-1)/12, (tt.tm_yday-1)/days_year
    feats = [f(r * 2 * np.pi) for r in date_fracs for f in fs]
    # Time-of-day cycles only apply to real datetimes, never to plain dates.
    if time and isinstance(d, datetime) and type(d) != date:
        time_fracs = tt.tm_hour/24, tt.tm_hour%12/12, tt.tm_min/60, tt.tm_sec/60
        feats += [f(r * 2 * np.pi) for r in time_fracs for f in fs]
    if not add_linear: return feats
    if type(d) == date:
        # For a plain date the day-of-year fraction is the fractional part of the year.
        feats.append(d.year + date_fracs[-1])
    else:
        year_start = datetime(d.year, 1, 1)
        secs_in_year = (datetime(d.year+1, 1, 1) - year_start).total_seconds()
        feats.append(d.year + ((d - datetime(d.year, 1, 1)).total_seconds() / secs_in_year))
    return feats
def add_cyclic_datepart(df:DataFrame, field_name:str, prefix:str=None, drop:bool=True, time:bool=False, add_linear:bool=False):
    "Helper function that adds trigonometric date/time features to a date in the column `field_name` of `df`."
    make_date(df, field_name)
    dates = df[field_name]
    # Default prefix: the column name without a trailing 'date'/'Date'.
    prefix = ifnone(prefix, re.sub('[Dd]ate$', '', field_name))
    featurize = partial(cyclic_dt_features, time=time, add_linear=add_linear)
    per_row = dates.apply(featurize)
    columns = [prefix + name for name in cyclic_dt_feat_names(time, add_linear)]
    # Spread the per-row feature lists over named columns, keeping the original index.
    df_feats = pd.DataFrame(list(per_row), columns=columns, index=per_row.index)
    for column in columns:
        df[column] = df_feats[column]
    if drop: df.drop(field_name, axis=1, inplace=True)
    return df
def add_datepart(df:DataFrame, field_name:str, prefix:str=None, drop:bool=True, time:bool=False):
    "Helper function that adds columns relevant to a date in the column `field_name` of `df`."
    make_date(df, field_name)
    field = df[field_name]
    # Default prefix: the column name without a trailing 'date'/'Date'.
    prefix = ifnone(prefix, re.sub('[Dd]ate$', '', field_name))
    attr = ['Year', 'Month', 'Week', 'Day', 'Dayofweek', 'Dayofyear', 'Is_month_end', 'Is_month_start',
            'Is_quarter_end', 'Is_quarter_start', 'Is_year_end', 'Is_year_start']
    if time: attr = attr + ['Hour', 'Minute', 'Second']
    # `Series.dt.week` no longer exists in recent pandas; `isocalendar().week` replaces it where available.
    has_isocalendar = hasattr(field.dt, 'isocalendar')
    for n in attr:
        if n == 'Week' and has_isocalendar:
            # Cast to the dtype of the other date parts so 'Week' is not a nullable UInt32 column.
            values = field.dt.isocalendar().week.astype(field.dt.day.dtype)
        else:
            values = getattr(field.dt, n.lower())
        df[prefix + n] = values
    # Integer view of the timestamps, floor-divided from nanoseconds to seconds.
    df[prefix + 'Elapsed'] = field.astype(np.int64) // 10 ** 9
    if drop: df.drop(field_name, axis=1, inplace=True)
    return df
def _get_elapsed(df:DataFrame,field_names:Collection[str], date_field:str, base_field:str, prefix:str):
    "For each field in `field_names`, add a `prefix+field` column: days between each row's date and the last row where that field was truthy, restarting whenever `base_field` changes."
    # Rows are walked in their current order; `add_elapsed_times` sorts by (`base_field`, `date_field`) before calling.
    for f in field_names:
        day1 = np.timedelta64(1, 'D')
        # `np.datetime64()` is NaT: no event has been seen yet.
        last_date,last_base,res = np.datetime64(),None,[]
        for b,v,d in zip(df[base_field].values, df[f].values, df[date_field].values):
            # A new `base_field` value starts a new group, so forget the previous event date.
            if last_base is None or b != last_base:
                last_date,last_base = np.datetime64(),b
            if v: last_date = d
            # The difference is expressed as a number of days.
            res.append(((d-last_date).astype('timedelta64[D]') / day1))
        df[prefix + f] = res
    return df
def add_elapsed_times(df:DataFrame, field_names:Collection[str], date_field:str, base_field:str):
    "Add `After*`/`Before*` elapsed-day columns and 7-row rolling sums (`_bw`/`_fw`) of `field_names`, grouped by `base_field`, and return the merged frame."
    field_names = listify(field_names)
    #Make sure date_field is a date and the event fields are bools
    df[field_names] = df[field_names].astype('bool')
    make_date(df, date_field)
    work_df = df[field_names + [date_field, base_field]]
    # Ascending dates give the time since the previous event...
    work_df = work_df.sort_values([base_field, date_field])
    work_df = _get_elapsed(work_df, field_names, date_field, base_field, 'After')
    # ...and descending dates the time relative to the next one.
    work_df = work_df.sort_values([base_field, date_field], ascending=[True, False])
    work_df = _get_elapsed(work_df, field_names, date_field, base_field, 'Before')
    for a in ['After' + f for f in field_names] + ['Before' + f for f in field_names]:
        work_df[a] = work_df[a].fillna(0).astype(int)
    for a,s in zip([True, False], ['_bw', '_fw']):
        work_df = work_df.set_index(date_field)
        tmp = (work_df[[base_field] + field_names].sort_index(ascending=a)
                      .groupby(base_field).rolling(7, min_periods=1).sum())
        # `axis` is passed by keyword: recent pandas rejects it as a positional argument.
        tmp.drop(base_field,axis=1,inplace=True)
        tmp.reset_index(inplace=True)
        work_df.reset_index(inplace=True)
        work_df = work_df.merge(tmp, 'left', [date_field, base_field], suffixes=['', s])
    # The raw event columns are already in `df`; drop them to avoid duplicates in the merge.
    work_df.drop(field_names,axis=1,inplace=True)
    return df.merge(work_df, 'left', [date_field, base_field])
def cont_cat_split(df, max_card=20, dep_var=None)->Tuple[List,List]:
    "Helper function that returns column names of cont and cat variables from given df."
    # A column is continuous if it is a float of any width, or an integer of any width with more
    # than `max_card` distinct values. Everything else (bools, strings, low-cardinality ints) is
    # categorical. `dep_var` is left out of both lists.
    is_integer, is_float = pd.api.types.is_integer_dtype, pd.api.types.is_float_dtype
    cont_names, cat_names = [], []
    for label in df:
        if label == dep_var: continue
        column = df[label]
        high_card_int = is_integer(column.dtype) and column.unique().shape[0] > max_card
        if high_card_int or is_float(column.dtype): cont_names.append(label)
        else: cat_names.append(label)
    return cont_names, cat_names
@dataclass
class TabularProc():
    "A processor for tabular dataframes."
    cat_names:StrList
    cont_names:StrList

    def __call__(self, df:DataFrame, test:bool=False):
        "Apply the correct function to `df` depending on `test`."
        if test: self.apply_test(df)
        else: self.apply_train(df)

    def apply_train(self, df:DataFrame):
        "Function applied to `df` if it's the train set."
        raise NotImplementedError

    def apply_test(self, df:DataFrame):
        "Function applied to `df` if it's the test set."
        # Unless a subclass overrides this, the test set is treated exactly like the train set.
        self.apply_train(df)
class Categorify(TabularProc):
    "Transform the categorical variables to that type."
    def apply_train(self, df:DataFrame):
        "Transform `self.cat_names` columns in categorical."
        self.categories = {}
        for n in self.cat_names:
            # Replace the whole column: `df.loc[:, n] = ...` writes into the existing column on
            # recent pandas and keeps its old dtype, which makes the `.cat` access below fail.
            df[n] = df[n].astype('category').cat.as_ordered()
            # Remember the categories so `apply_test` can reuse the same coding.
            self.categories[n] = df[n].cat.categories

    def apply_test(self, df:DataFrame):
        "Transform `self.cat_names` columns in categorical using the codes decided in `apply_train`."
        for n in self.cat_names:
            df[n] = pd.Categorical(df[n], categories=self.categories[n], ordered=True)
# Strategies understood by `FillMissing`; the functional `IntEnum` API numbers them MEDIAN=1, COMMON=2, CONSTANT=3.
FillStrategy = IntEnum('FillStrategy', 'MEDIAN COMMON CONSTANT')
@dataclass
class FillMissing(TabularProc):
    "Fill the missing values in continuous columns."
    fill_strategy:FillStrategy=FillStrategy.MEDIAN
    add_col:bool=True
    fill_val:float=0.

    def apply_train(self, df:DataFrame):
        "Fill missing values in `self.cont_names` according to `self.fill_strategy`."
        self.na_dict = {}
        for name in self.cont_names:
            # Columns without any missing value are left alone and not recorded.
            if not pd.isnull(df[name]).sum(): continue
            if self.add_col:
                # Keep track of which rows were missing in an extra categorical column.
                na_col = name+'_na'
                df[na_col] = pd.isnull(df[name])
                if na_col not in self.cat_names: self.cat_names.append(na_col)
            strategy = self.fill_strategy
            if strategy == FillStrategy.MEDIAN: filler = df[name].median()
            elif strategy == FillStrategy.CONSTANT: filler = self.fill_val
            else: filler = df[name].dropna().value_counts().idxmax()
            df[name] = df[name].fillna(filler)
            # Remember the filler so `apply_test` can reuse it.
            self.na_dict[name] = filler

    def apply_test(self, df:DataFrame):
        "Fill missing values in `self.cont_names` like in `apply_train`."
        for name in self.cont_names:
            if name not in self.na_dict:
                # The column had nothing missing at training time, so no filler is known for it.
                if pd.isnull(df[name]).sum() != 0:
                    raise Exception(f"""There are nan values in field {name} but there were none in the training set.
Please fix those manually.""")
                continue
            if self.add_col:
                na_col = name+'_na'
                df[na_col] = pd.isnull(df[name])
                if na_col not in self.cat_names: self.cat_names.append(na_col)
            df[name] = df[name].fillna(self.na_dict[name])
class Normalize(TabularProc):
    "Normalize the continuous variables."
    def apply_train(self, df:DataFrame):
        "Compute the means and stds of `self.cont_names` columns to normalize them."
        self.means = {}
        self.stds = {}
        for n in self.cont_names:
            column = df[n]
            assert is_numeric_dtype(column), (f"""Cannot normalize '{n}' column as it isn't numerical.
Are you sure it doesn't belong in the categorical set of columns?""")
            mean, std = column.mean(), column.std()
            self.means[n] = mean
            self.stds[n] = std
            # The small constant keeps the division finite for constant columns.
            df[n] = (column-mean) / (1e-7 + std)

    def apply_test(self, df:DataFrame):
        "Normalize `self.cont_names` with the same statistics as in `apply_train`."
        for n in self.cont_names:
            mean, std = self.means[n], self.stds[n]
            df[n] = (df[n]-mean) / (1e-7 + std)
================================================
FILE: fastai/test_registry.json
================================================
{
"fastai.basic_data.DataBunch": [
{
"file": "tests/test_data_block.py",
"line": 152,
"test": "test_custom_dataset"
}
],
"fastai.basic_data.DataBunch.create": [
{
"file": "tests/test_basic_data.py",
"line": 30,
"test": "test_DataBunch_Create"
},
{
"file": "tests/test_basic_data.py",
"line": 44,
"test": "test_DataBunch_no_valid_dl"
}
],
"fastai.basic_data.DataBunch.one_batch": [
{
"file": "tests/test_basic_data.py",
"line": 58,
"test": "test_DataBunch_onebatch"
},
{
"file": "tests/test_text_data.py",
"line": 83,
"test": "test_should_load_backwards_lm_1"
},
{
"file": "tests/test_text_data.py",
"line": 99,
"test": "test_should_load_backwards_lm_2"
},
{
"file": "tests/test_text_data.py",
"line": 110,
"test": "test_backwards_cls_databunch"
},
{
"file": "tests/test_basic_data.py",
"line": 83,
"test": "test_DataBunch_save_load"
}
],
"fastai.basic_data.DataBunch.one_item": [
{
"file": "tests/test_basic_data.py",
"line": 67,
"test": "test_DataBunch_oneitem"
}
],
"fastai.basic_data.DataBunch.save": [
{
"file": "tests/test_basic_data.py",
"line": 83,
"test": "test_DataBunch_save_load"
}
],
"fastai.basic_data.DataBunch.show_batch": [
{
"file": "tests/test_basic_data.py",
"line": 75,
"test": "test_DataBunch_show_batch"
}
],
"fastai.basic_data.intercept_args": [
{
"file": "tests/test_basic_data.py",
"line": 18,
"test": "test_intercept_args"
}
],
"fastai.basic_data.load_data": [
{
"file": "tests/test_text_data.py",
"line": 129,
"test": "test_load_and_save_test"
},
{
"file": "tests/test_basic_data.py",
"line": 83,
"test": "test_DataBunch_save_load"
}
],
"fastai.basic_train.Learner.destroy": [
{
"file": "tests/test_basic_train.py",
"line": 170,
"test": "test_destroy"
},
{
"file": "tests/test_basic_train.py",
"line": 213,
"test": "test_memory"
}
],
"fastai.basic_train.Learner.export": [
{
"file": "tests/test_basic_train.py",
"line": 230,
"test": "test_export_load_learner"
}
],
"fastai.basic_train.Learner.fit": [
{
"file": "tests/test_train.py",
"line": 28,
"test": "test_fit"
}
],
"fastai.basic_train.Learner.freeze": [
{
"file": "tests/test_basic_train.py",
"line": 49,
"test": "test_freeze"
}
],
"fastai.basic_train.Learner.freeze_to": [
{
"file": "tests/test_basic_train.py",
"line": 39,
"test": "test_freeze_to"
}
],
"fastai.basic_train.Learner.get_preds": [
{
"file": "tests/test_basic_train.py",
"line": 32,
"test": "test_get_preds"
}
],
"fastai.basic_train.Learner.load": [
{
"file": "tests/test_basic_train.py",
"line": 104,
"test": "test_save_load"
},
{
"file": "tests/test_basic_train.py",
"line": 213,
"test": "test_memory"
}
],
"fastai.basic_train.Learner.predict": [
{
"file": "tests/test_vision_train.py",
"line": 63,
"test": "test_preds"
},
{
"file": "tests/test_vision_train.py",
"line": 89,
"test": "test_models_meta"
}
],
"fastai.basic_train.Learner.purge": [
{
"file": "tests/test_basic_train.py",
"line": 76,
"test": "test_purge"
},
{
"file": "tests/test_basic_train.py",
"line": 104,
"test": "test_save_load"
},
{
"file": "tests/test_basic_train.py",
"line": 213,
"test": "test_memory"
}
],
"fastai.basic_train.Learner.save": [
{
"file": "tests/test_basic_train.py",
"line": 104,
"test": "test_save_load"
},
{
"file": "tests/test_basic_train.py",
"line": 213,
"test": "test_memory"
}
],
"fastai.basic_train.Learner.unfreeze": [
{
"file": "tests/test_basic_train.py",
"line": 58,
"test": "test_unfreeze"
}
],
"fastai.basic_train.Learner.validate": [
{
"file": "tests/test_collab_train.py",
"line": 16,
"test": "test_val_loss"
},
{
"file": "tests/test_text_train.py",
"line": 56,
"test": "test_val_loss"
}
],
"fastai.basic_train.Recorder": [
{
"file": "tests/test_vision_train.py",
"line": 49,
"test": "test_1cycle_lrs"
},
{
"file": "tests/test_vision_train.py",
"line": 56,
"test": "test_1cycle_moms"
}
],
"fastai.basic_train.load_learner": [
{
"file": "tests/test_basic_train.py",
"line": 230,
"test": "test_export_load_learner"
}
],
"fastai.basic_train.validate": [
{
"file": "tests/test_tabular_train.py",
"line": 26,
"test": "test_accuracy"
}
],
"fastai.callback.AverageMetric": [
{
"file": "tests/test_metrics.py",
"line": 213,
"test": "test_average_metric_naming"
}
],
"fastai.callback.Callback": [
{
"file": "tests/test_callback.py",
"line": 33,
"test": "test_callbacks_learner"
},
{
"file": "tests/test_callback.py",
"line": 64,
"test": "test_callbacks_fit"
}
],
"fastai.callbacks.csv_logger.CSVLogger": [
{
"file": "tests/test_callbacks_csv_logger.py",
"line": 37,
"test": "test_logger"
}
],
"fastai.callbacks.hooks.hook_output": [
{
"file": "tests/test_callbacks_hooks.py",
"line": 74,
"test": "test_hook_output_basics"
}
],
"fastai.callbacks.hooks.model_summary": [
{
"file": "tests/test_callbacks_hooks.py",
"line": 18,
"test": "test_model_summary_vision"
},
{
"file": "tests/test_callbacks_hooks.py",
"line": 26,
"test": "test_model_summary_text"
},
{
"file": "tests/test_callbacks_hooks.py",
"line": 33,
"test": "test_model_summary_tabular"
},
{
"file": "tests/test_callbacks_hooks.py",
"line": 48,
"test": "test_model_summary_collab"
},
{
"file": "tests/test_basic_train.py",
"line": 230,
"test": "test_export_load_learner"
}
],
"fastai.callbacks.mem.PeakMemMetric": [
{
"file": "tests/test_callbacks_mem.py",
"line": 8,
"test": "test_peak_mem_metric"
}
],
"fastai.callbacks.misc.StopAfterNBatches": [
{
"file": "tests/test_callbacks_misc.py",
"line": 22,
"test": "test_stop_after_n_batches"
}
],
"fastai.core.Category": [
{
"file": "tests/test_core.py",
"line": 242,
"test": "test_itembase_eq"
}
],
"fastai.core.Category.__hash__": [
{
"file": "tests/test_core.py",
"line": 304,
"test": "test_itembase_hash"
}
],
"fastai.core.FloatItem": [
{
"file": "tests/test_core.py",
"line": 242,
"test": "test_itembase_eq"
}
],
"fastai.core.FloatItem.__hash__": [
{
"file": "tests/test_core.py",
"line": 304,
"test": "test_itembase_hash"
}
],
"fastai.core.ItemBase.__eq__": [
{
"file": "tests/test_core.py",
"line": 242,
"test": "test_itembase_eq"
},
{
"file": "tests/test_core.py",
"line": 304,
"test": "test_itembase_hash"
}
],
"fastai.core.MultiCategory": [
{
"file": "tests/test_core.py",
"line": 242,
"test": "test_itembase_eq"
}
],
"fastai.core.MultiCategory.__hash__": [
{
"file": "tests/test_core.py",
"line": 304,
"test": "test_itembase_hash"
}
],
"fastai.core.arrays_split": [
{
"file": "tests/test_core.py",
"line": 141,
"test": "test_arrays_split"
}
],
"fastai.core.camel2snake": [
{
"file": "tests/test_core.py",
"line": 164,
"test": "test_camel2snake"
}
],
"fastai.core.chunks": [
{
"file": "tests/test_core.py",
"line": 46,
"test": "test_chunks"
}
],
"fastai.core.df_names_to_idx": [
{
"file": "tests/test_core.py",
"line": 213,
"test": "test_df_names_to_idx"
}
],
"fastai.core.download_url": [
{
"file": "tests/test_core.py",
"line": 193,
"test": "test_download_url"
}
],
"fastai.core.even_mults": [
{
"file": "tests/test_core.py",
"line": 178,
"test": "test_even_mults"
}
],
"fastai.core.find_classes": [
{
"file": "tests/test_core.py",
"line": 131,
"test": "test_find_classes"
}
],
"fastai.core.idx_dict": [
{
"file": "tests/test_core.py",
"line": 125,
"test": "test_idx_dict"
}
],
"fastai.core.ifnone": [
{
"file": "tests/test_core.py",
"line": 39,
"test": "test_ifnone"
}
],
"fastai.core.is1d": [
{
"file": "tests/test_core.py",
"line": 235,
"test": "test_is1d"
}
],
"fastai.core.is_dict": [
{
"file": "tests/test_core.py",
"line": 76,
"test": "test_dict"
}
],
"fastai.core.is_listy": [
{
"file": "tests/test_core.py",
"line": 59,
"test": "test_listy"
}
],
"fastai.core.is_tuple": [
{
"file": "tests/test_core.py",
"line": 70,
"test": "test_tuple"
}
],
"fastai.core.join_path": [
{
"file": "tests/test_core.py",
"line": 206,
"test": "test_join_paths"
}
],
"fastai.core.listify": [
{
"file": "tests/test_core.py",
"line": 25,
"test": "test_listify"
}
],
"fastai.core.noop": [
{
"file": "tests/test_core.py",
"line": 82,
"test": "test_noop"
}
],
"fastai.core.num_cpus": [
{
"file": "tests/test_core.py",
"line": 8,
"test": "test_cpus"
}
],
"fastai.core.one_hot": [
{
"file": "tests/test_core.py",
"line": 218,
"test": "test_one_hot"
}
],
"fastai.core.partition": [
{
"file": "tests/test_core.py",
"line": 94,
"test": "test_partition_functionality"
}
],
"fastai.core.random_split": [
{
"file": "tests/test_core.py",
"line": 154,
"test": "test_random_split"
}
],
"fastai.core.recurse": [
{
"file": "tests/test_core.py",
"line": 29,
"test": "test_recurse"
}
],
"fastai.core.series2cat": [
{
"file": "tests/test_core.py",
"line": 184,
"test": "test_series2cat"
}
],
"fastai.core.subplots": [
{
"file": "tests/test_core.py",
"line": 222,
"test": "test_subplots_multi_row_cols"
},
{
"file": "tests/test_core.py",
"line": 229,
"test": "test_subplots_single"
}
],
"fastai.core.to_int": [
{
"file": "tests/test_core.py",
"line": 86,
"test": "test_to_int"
}
],
"fastai.core.uniqueify": [
{
"file": "tests/test_core.py",
"line": 53,
"test": "test_uniqueify"
}
],
"fastai.data_block.CategoryProcessor.process_one": [
{
"file": "tests/test_data_block.py",
"line": 80,
"test": "test_category_processor_existing_class"
},
{
"file": "tests/test_data_block.py",
"line": 91,
"test": "test_category_processor_non_existing_class"
}
],
"fastai.data_block.ItemList.filter_by_folder": [
{
"file": "tests/test_data_block.py",
"line": 161,
"test": "test_filter_by_folder"
}
],
"fastai.data_block.ItemList.filter_by_rand": [
{
"file": "tests/test_data_block.py",
"line": 112,
"test": "test_filter_by_rand"
}
],
"fastai.data_block.ItemList.label_from_folder": [
{
"file": "tests/test_text_data.py",
"line": 30,
"test": "test_from_folder"
},
{
"file": "tests/test_text_data.py",
"line": 42,
"test": "test_filter_classes"
}
],
"fastai.data_block.ItemList.split_by_rand_pct": [
{
"file": "tests/test_data_block.py",
"line": 103,
"test": "test_splitdata_datasets"
}
],
"fastai.data_block.ItemList.split_subsets": [
{
"file": "tests/test_data_block.py",
"line": 121,
"test": "test_split_subsets"
}
],
"fastai.data_block.LabelLists.databunch": [
{
"file": "tests/test_vision_data.py",
"line": 217,
"test": "test_vision_datasets"
}
],
"fastai.datasets.Config": [
{
"file": "tests/test_datasets.py",
"line": 15,
"test": "test_creates_config"
},
{
"file": "tests/test_datasets.py",
"line": 26,
"test": "test_load_config"
},
{
"file": "tests/test_datasets.py",
"line": 29,
"test": "test_default_config"
},
{
"file": "tests/test_datasets.py",
"line": 42,
"test": "test_user_config"
}
],
"fastai.datasets.datapath4file": [
{
"file": "tests/test_datasets.py",
"line": 26,
"test": "test_load_config"
},
{
"file": "tests/test_datasets.py",
"line": 42,
"test": "test_user_config"
}
],
"fastai.datasets.download_data": [
{
"file": "tests/test_datasets.py",
"line": 26,
"test": "test_load_config"
},
{
"file": "tests/test_datasets.py",
"line": 42,
"test": "test_user_config"
}
],
"fastai.datasets.untar_data": [
{
"file": "tests/test_vision_data.py",
"line": 165,
"test": "test_trunc_download"
},
{
"file": "tests/test_datasets.py",
"line": 26,
"test": "test_load_config"
},
{
"file": "tests/test_datasets.py",
"line": 42,
"test": "test_user_config"
}
],
"fastai.datasets.url2path": [
{
"file": "tests/test_datasets.py",
"line": 26,
"test": "test_load_config"
},
{
"file": "tests/test_datasets.py",
"line": 42,
"test": "test_user_config"
}
],
"fastai.gen_doc.doctest.merge_registries": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 199,
"test": "test_merge_registries"
}
],
"fastai.gen_doc.doctest.this_tests": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 75,
"test": "test_this_tests"
}
],
"fastai.gen_doc.nbtest._fuzzy_line_match": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 61,
"test": "test_fuzzy_line_match"
}
],
"fastai.gen_doc.nbtest._is_file_match": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 16,
"test": "test_is_file_match"
}
],
"fastai.gen_doc.nbtest._submodule_name": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 7,
"test": "test_submodule_name"
}
],
"fastai.gen_doc.nbtest.direct_test_match": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 38,
"test": "test_direct_test_match"
},
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 46,
"test": "test_direct_test_match_class_methods"
}
],
"fastai.gen_doc.nbtest.fuzzy_test_match": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 38,
"test": "test_fuzzy_test_match"
}
],
"fastai.gen_doc.nbtest.get_file": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 26,
"test": "test_wrapped_functions"
}
],
"fastai.gen_doc.nbtest.get_tests_dir": [
{
"file": "tests/test_gen_doc_nbtest.py",
"line": 70,
"test": "test_get_tests_dir"
}
],
"fastai.layers.SelfAttention": [
{
"file": "tests/test_torch_core.py",
"line": 269,
"test": "test_keep_parameter"
}
],
"fastai.metrics.accuracy": [
{
"file": "tests/test_metrics.py",
"line": 44,
"test": "test_accuracy"
},
{
"file": "tests/test_vision_train.py",
"line": 41,
"test": "test_accuracy"
}
],
"fastai.metrics.accuracy_thresh": [
{
"file": "tests/test_metrics.py",
"line": 99,
"test": "test_accuracy_thresh"
}
],
"fastai.metrics.dice": [
{
"file": "tests/test_metrics.py",
"line": 108,
"test": "test_dice"
},
{
"file": "tests/test_metrics.py",
"line": 118,
"test": "test_dice_iou"
}
],
"fastai.metrics.error_rate": [
{
"file": "tests/test_metrics.py",
"line": 86,
"test": "test_error_rate"
},
{
"file": "tests/test_vision_train.py",
"line": 45,
"test": "test_error_rate"
}
],
"fastai.metrics.exp_rmspe": [
{
"file": "tests/test_metrics.py",
"line": 90,
"test": "test_exp_rmspe"
},
{
"file": "tests/test_metrics.py",
"line": 94,
"test": "test_exp_rmspe_num_of_ele"
}
],
"fastai.metrics.explained_variance": [
{
"file": "tests/test_metrics.py",
"line": 174,
"test": "test_explained_variance"
}
],
"fastai.metrics.fbeta": [
{
"file": "tests/test_metrics.py",
"line": 126,
"test": "test_fbeta"
}
],
"fastai.metrics.foreground_acc": [
{
"file": "tests/test_metrics.py",
"line": 78,
"test": "test_foreground_acc"
}
],
"fastai.metrics.mean_absolute_error": [
{
"file": "tests/test_metrics.py",
"line": 135,
"test": "test_mae"
}
],
"fastai.metrics.mean_squared_error": [
{
"file": "tests/test_metrics.py",
"line": 144,
"test": "test_mse"
}
],
"fastai.metrics.mean_squared_logarithmic_error": [
{
"file": "tests/test_metrics.py",
"line": 163,
"test": "test_msle"
}
],
"fastai.metrics.r2_score": [
{
"file": "tests/test_metrics.py",
"line": 185,
"test": "test_r2_score"
}
],
"fastai.metrics.root_mean_squared_error": [
{
"file": "tests/test_metrics.py",
"line": 153,
"test": "test_rmse"
}
],
"fastai.metrics.top_k_accuracy": [
{
"file": "tests/test_metrics.py",
"line": 69,
"test": "test_top_k_accuracy"
}
],
"fastai.tabular.data.TabularList.from_df": [
{
"file": "tests/test_tabular_data.py",
"line": 5,
"test": "test_from_df"
}
],
"fastai.tabular.models._cl_int_from_learner": [
{
"file": "tests/test_vision_train.py",
"line": 72,
"test": "test_interp"
}
],
"fastai.tabular.models._learner_interpret": [
{
"file": "tests/test_vision_train.py",
"line": 78,
"test": "test_interp_shortcut"
}
],
"fastai.tabular.transform.Categorify": [
{
"file": "tests/test_tabular_transform.py",
"line": 6,
"test": "test_categorify"
}
],
"fastai.tabular.transform.FillMissing": [
{
"file": "tests/test_tabular_transform.py",
"line": 30,
"test": "test_default_fill_strategy_is_median"
}
],
"fastai.tabular.transform.FillMissing.apply_test": [
{
"file": "tests/test_tabular_transform.py",
"line": 36,
"test": "test_fill_missing_leaves_no_na_values"
},
{
"file": "tests/test_tabular_transform.py",
"line": 49,
"test": "test_fill_missing_returns_correct_medians"
}
],
"fastai.tabular.transform.FillMissing.apply_train": [
{
"file": "tests/test_tabular_transform.py",
"line": 36,
"test": "test_fill_missing_leaves_no_na_values"
},
{
"file": "tests/test_tabular_transform.py",
"line": 49,
"test": "test_fill_missing_returns_correct_medians"
}
],
"fastai.tabular.transform.cont_cat_split": [
{
"file": "tests/test_tabular_transform.py",
"line": 64,
"test": "test_cont_cat_split"
}
],
"fastai.text.data.SortSampler": [
{
"file": "tests/test_text_data.py",
"line": 158,
"test": "test_sampler"
},
{
"file": "tests/test_text_data.py",
"line": 158,
"test": "test_sort_sampler"
}
],
"fastai.text.data.SortishSampler": [
{
"file": "tests/test_text_data.py",
"line": 143,
"test": "test_sortish_sampler"
}
],
"fastai.text.data.TextDataBunch.from_csv": [
{
"file": "tests/test_text_data.py",
"line": 57,
"test": "test_from_csv_and_from_df"
}
],
"fastai.text.data.TextDataBunch.from_df": [
{
"file": "tests/test_text_data.py",
"line": 57,
"test": "test_from_csv_and_from_df"
}
],
"fastai.text.data.TextDataBunch.from_ids": [
{
"file": "tests/test_text_data.py",
"line": 173,
"test": "test_from_ids_works_for_equally_length_sentences"
},
{
"file": "tests/test_text_data.py",
"line": 181,
"test": "test_from_ids_works_for_variable_length_sentences"
},
{
"file": "tests/test_text_data.py",
"line": 189,
"test": "test_from_ids_exports_classes"
}
],
"fastai.text.learner.language_model_learner": [
{
"file": "tests/test_text_train.py",
"line": 61,
"test": "test_qrnn_works_with_no_split"
},
{
"file": "tests/test_text_train.py",
"line": 73,
"test": "test_qrnn_works_if_split_fn_provided"
}
],
"fastai.text.learner.text_classifier_learner": [
{
"file": "tests/test_text_train.py",
"line": 100,
"test": "test_classifier"
},
{
"file": "tests/test_text_train.py",
"line": 139,
"test": "test_order_preds"
}
],
"fastai.text.models.qrnn.BwdForgetMultGPU": [
{
"file": "tests/test_text_qrnn.py",
"line": 28,
"test": "test_forget_mult_cuda"
}
],
"fastai.text.models.qrnn.ForgetMultGPU": [
{
"file": "tests/test_text_qrnn.py",
"line": 7,
"test": "test_forget_mult_forward_gpu"
},
{
"file": "tests/test_text_qrnn.py",
"line": 27,
"test": "test_compare_forget_mult_forward_implementations"
},
{
"file": "tests/test_text_qrnn.py",
"line": 28,
"test": "test_forget_mult_cuda"
}
],
"fastai.text.models.qrnn.QRNN": [
{
"file": "tests/test_text_qrnn.py",
"line": 105,
"test": "test_qrnn_bidir"
}
],
"fastai.text.models.qrnn.QRNNLayer": [
{
"file": "tests/test_text_qrnn.py",
"line": 89,
"test": "test_qrnn_layer"
}
],
"fastai.text.models.qrnn.forget_mult_CPU": [
{
"file": "tests/test_text_qrnn.py",
"line": 75,
"test": "test_forget_mult"
}
],
"fastai.text.transform.Tokenizer": [
{
"file": "tests/test_text_transform.py",
"line": 15,
"test": "test_tokenize"
},
{
"file": "tests/test_text_transform.py",
"line": 24,
"test": "test_tokenize_handles_empty_lines"
},
{
"file": "tests/test_text_transform.py",
"line": 32,
"test": "test_tokenize_ignores_extraneous_space"
}
],
"fastai.text.transform.Vocab.numericalize": [
{
"file": "tests/test_text_transform.py",
"line": 39,
"test": "test_numericalize_and_textify"
}
],
"fastai.text.transform.Vocab.textify": [
{
"file": "tests/test_text_transform.py",
"line": 39,
"test": "test_numericalize_and_textify"
}
],
"fastai.text.transform.deal_caps": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.text.transform.fix_html": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.text.transform.replace_all_caps": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.text.transform.replace_rep": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.text.transform.replace_wrep": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.text.transform.rm_useless_spaces": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.text.transform.spec_add_spaces": [
{
"file": "tests/test_text_transform.py",
"line": 5,
"test": "test_rules"
}
],
"fastai.torch_core.NoneReduceOnCPU": [
{
"file": "tests/test_torch_core.py",
"line": 249,
"test": "test_none_reduce_on_cpu"
}
],
"fastai.torch_core.apply_init": [
{
"file": "tests/test_torch_core.py",
"line": 47,
"test": "test_apply_init"
}
],
"fastai.torch_core.apply_leaf": [
{
"file": "tests/test_torch_core.py",
"line": 47,
"test": "test_apply_init"
}
],
"fastai.torch_core.batch_to_half": [
{
"file": "tests/test_fp16.py",
"line": 32,
"test": "test_batch_to_half"
}
],
"fastai.torch_core.children": [
{
"file": "tests/test_torch_core.py",
"line": 197,
"test": "test_children"
}
],
"fastai.torch_core.first_layer": [
{
"file": "tests/test_torch_core.py",
"line": 213,
"test": "test_first_layer"
}
],
"fastai.torch_core.in_channels": [
{
"file": "tests/test_torch_core.py",
"line": 59,
"test": "test_in_channels"
},
{
"file": "tests/test_torch_core.py",
"line": 64,
"test": "test_in_channels_no_weights"
}
],
"fastai.torch_core.last_layer": [
{
"file": "tests/test_torch_core.py",
"line": 220,
"test": "test_last_layer"
}
],
"fastai.torch_core.model2half": [
{
"file": "tests/test_fp16.py",
"line": 6,
"test": "test_model2half"
},
{
"file": "tests/test_fp16.py",
"line": 16,
"test": "test_model2half_forward"
}
],
"fastai.torch_core.model_type": [
{
"file": "tests/test_torch_core.py",
"line": 227,
"test": "test_model_type"
}
],
"fastai.torch_core.np2model_tensor": [
{
"file": "tests/test_torch_core.py",
"line": 94,
"test": "test_np2model_tensor"
}
],
"fastai.torch_core.np_address": [
{
"file": "tests/test_torch_core.py",
"line": 100,
"test": "test_np_address"
}
],
"fastai.torch_core.num_children": [
{
"file": "tests/test_torch_core.py",
"line": 206,
"test": "test_num_children"
}
],
"fastai.torch_core.range_children": [
{
"file": "tests/test_torch_core.py",
"line": 70,
"test": "test_range_children"
}
],
"fastai.torch_core.requires_grad": [
{
"file": "tests/test_torch_core.py",
"line": 32,
"test": "test_requires_grad"
},
{
"file": "tests/test_torch_core.py",
"line": 37,
"test": "test_requires_grad_set"
}
],
"fastai.torch_core.set_bn_eval": [
{
"file": "tests/test_torch_core.py",
"line": 87,
"test": "test_set_bn_eval"
}
],
"fastai.torch_core.split_model": [
{
"file": "tests/test_torch_core.py",
"line": 75,
"test": "test_split_model"
}
],
"fastai.torch_core.split_no_wd_params": [
{
"file": "tests/test_torch_core.py",
"line": 81,
"test": "test_split_no_wd_params"
}
],
"fastai.torch_core.tensor": [
{
"file": "tests/test_torch_core.py",
"line": 13,
"test": "test_tensor_with_list"
},
{
"file": "tests/test_torch_core.py",
"line": 18,
"test": "test_tensor_with_ndarray"
},
{
"file": "tests/test_torch_core.py",
"line": 25,
"test": "test_tensor_with_tensor"
}
],
"fastai.torch_core.to_cpu": [
{
"file": "tests/test_torch_core.py",
"line": 154,
"test": "test_to_cpu"
}
],
"fastai.torch_core.to_data": [
{
"file": "tests/test_torch_core.py",
"line": 106,
"test": "test_to_data"
}
],
"fastai.torch_core.to_detach": [
{
"file": "tests/test_torch_core.py",
"line": 131,
"test": "test_to_detach"
}
],
"fastai.torch_core.to_float": [
{
"file": "tests/test_torch_core.py",
"line": 184,
"test": "test_to_float"
}
],
"fastai.torch_core.to_half": [
{
"file": "tests/test_fp16.py",
"line": 25,
"test": "test_to_half"
},
{
"file": "tests/test_torch_core.py",
"line": 171,
"test": "test_to_half"
}
],
"fastai.torch_core.to_np": [
{
"file": "tests/test_torch_core.py",
"line": 244,
"test": "test_to_np"
}
],
"fastai.torch_core.trange_of": [
{
"file": "tests/test_torch_core.py",
"line": 236,
"test": "test_trange_of"
}
],
"fastai.train.ClassificationInterpretation": [
{
"file": "tests/test_vision_train.py",
"line": 95,
"test": "test_ClassificationInterpretation"
}
],
"fastai.train.ClassificationInterpretation.confusion_matrix": [
{
"file": "tests/test_tabular_train.py",
"line": 84,
"test": "test_confusion_tabular"
}
],
"fastai.train.fit_one_cycle": [
{
"file": "tests/test_train.py",
"line": 36,
"test": "test_fit_one_cycle"
}
],
"fastai.train.lr_find": [
{
"file": "tests/test_train.py",
"line": 16,
"test": "test_lr_find"
},
{
"file": "tests/test_vision_train.py",
"line": 84,
"test": "test_lrfind"
}
],
"fastai.utils.collect_env.check_perf": [
{
"file": "tests/test_utils.py",
"line": 18,
"test": "test_check_perf"
}
],
"fastai.utils.collect_env.show_install": [
{
"file": "tests/test_utils.py",
"line": 8,
"test": "test_show_install"
}
],
"fastai.utils.mem.GPUMemTrace": [
{
"file": "tests/test_utils_mem.py",
"line": 76,
"test": "test_gpu_mem_trace"
},
{
"file": "tests/test_utils_mem.py",
"line": 137,
"test": "test_gpu_mem_trace_ctx"
}
],
"fastai.utils.mem.gpu_mem_get": [
{
"file": "tests/test_utils_mem.py",
"line": 25,
"test": "test_gpu_mem_by_id"
}
],
"fastai.utils.mem.gpu_mem_get_all": [
{
"file": "tests/test_utils_mem.py",
"line": 35,
"test": "test_gpu_mem_all"
}
],
"fastai.utils.mem.gpu_mem_get_used": [
{
"file": "tests/test_utils_mem.py",
"line": 56,
"test": "test_gpu_mem_measure_consumed_reclaimed"
}
],
"fastai.utils.mem.gpu_mem_trace": [
{
"file": "tests/test_utils_mem.py",
"line": 178,
"test": "test_gpu_mem_trace_decorator"
}
],
"fastai.utils.mem.gpu_with_max_free_mem": [
{
"file": "tests/test_utils_mem.py",
"line": 44,
"test": "test_gpu_with_max_free_mem"
}
],
"fastai.utils.mod_display.progress_disabled_ctx": [
{
"file": "tests/test_mod_display.py",
"line": 16,
"test": "test_progress_disabled_ctx"
}
],
"fastai.vision.data.ImageDataBunch.from_csv": [
{
"file": "tests/test_vision_data.py",
"line": 22,
"test": "test_path_can_be_str_type"
},
{
"file": "tests/test_vision_data.py",
"line": 54,
"test": "test_from_csv_and_from_df"
}
],
"fastai.vision.data.ImageDataBunch.from_df": [
{
"file": "tests/test_vision_data.py",
"line": 54,
"test": "test_from_csv_and_from_df"
}
],
"fastai.vision.data.ImageDataBunch.from_folder": [
{
"file": "tests/test_vision_data.py",
"line": 26,
"test": "test_from_folder"
}
],
"fastai.vision.data.ImageDataBunch.from_lists": [
{
"file": "tests/test_vision_data.py",
"line": 39,
"test": "test_from_lists"
}
],
"fastai.vision.data.ImageDataBunch.from_name_re": [
{
"file": "tests/test_vision_data.py",
"line": 32,
"test": "test_from_name_re"
},
{
"file": "tests/test_vision_data.py",
"line": 70,
"test": "test_image_resize"
}
],
"fastai.vision.data.ImageDataBunch.normalize": [
{
"file": "tests/test_vision_data.py",
"line": 120,
"test": "test_normalize"
}
],
"fastai.vision.data.ImageList.from_csv": [
{
"file": "tests/test_vision_data.py",
"line": 227,
"test": "test_multi"
}
],
"fastai.vision.data.ImageList.from_folder": [
{
"file": "tests/test_vision_data.py",
"line": 217,
"test": "test_vision_datasets"
},
{
"file": "tests/test_vision_gan.py",
"line": 30,
"test": "test_gan_datasets"
}
],
"fastai.vision.data.ObjectItemList": [
{
"file": "tests/test_vision_data.py",
"line": 267,
"test": "test_coco"
},
{
"file": "tests/test_vision_data.py",
"line": 280,
"test": "test_coco_same_size"
},
{
"file": "tests/test_vision_data.py",
"line": 297,
"test": "test_coco_pickle"
}
],
"fastai.vision.data.PointsItemList": [
{
"file": "tests/test_vision_data.py",
"line": 254,
"test": "test_points"
}
],
"fastai.vision.data.SegmentationItemList": [
{
"file": "tests/test_vision_data.py",
"line": 238,
"test": "test_camvid"
}
],
"fastai.vision.data.denormalize": [
{
"file": "tests/test_vision_data.py",
"line": 134,
"test": "test_denormalize"
}
],
"fastai.vision.data.download_images": [
{
"file": "tests/test_vision_data.py",
"line": 144,
"test": "test_download_images"
}
],
"fastai.vision.data.verify_image": [
{
"file": "tests/test_vision_data.py",
"line": 201,
"test": "test_verify_image"
}
],
"fastai.vision.data.verify_images": [
{
"file": "tests/test_vision_data.py",
"line": 190,
"test": "test_verify_images"
}
],
"fastai.vision.gan.GANModule": [
{
"file": "tests/test_vision_gan.py",
"line": 67,
"test": "test_gan_module"
}
],
"fastai.vision.gan.GANTrainer": [
{
"file": "tests/test_vision_gan.py",
"line": 80,
"test": "test_gan_trainer"
}
],
"fastai.vision.gan.NoisyItem": [
{
"file": "tests/test_vision_gan.py",
"line": 37,
"test": "test_noisy_item"
}
],
"fastai.vision.gan.basic_critic": [
{
"file": "tests/test_vision_gan.py",
"line": 56,
"test": "test_basic_critic"
}
],
"fastai.vision.gan.basic_generator": [
{
"file": "tests/test_vision_gan.py",
"line": 46,
"test": "test_basic_generator"
}
],
"fastai.vision.image.Image": [
{
"file": "tests/test_vision_transform.py",
"line": 58,
"test": "test_mask_data_aug"
}
],
"fastai.vision.image.Image.resize": [
{
"file": "tests/test_vision_image.py",
"line": 49,
"test": "test_image_resize_same_size_shortcut"
}
],
"fastai.vision.image.ImageBBox": [
{
"file": "tests/test_vision_transform.py",
"line": 37,
"test": "test_bbox_data_aug"
}
],
"fastai.vision.image.ImagePoints": [
{
"file": "tests/test_vision_transform.py",
"line": 22,
"test": "test_points_data_aug"
}
],
"fastai.vision.image.ImageSegment": [
{
"file": "tests/test_vision_transform.py",
"line": 58,
"test": "test_mask_data_aug"
}
],
"fastai.vision.image.pil2tensor": [
{
"file": "tests/test_vision_data.py",
"line": 348,
"test": "test_vision_pil2tensor"
},
{
"file": "tests/test_vision_data.py",
"line": 379,
"test": "test_vision_pil2tensor_16bit"
},
{
"file": "tests/test_vision_data.py",
"line": 386,
"test": "test_vision_pil2tensor_numpy"
}
],
"fastai.vision.image.rle_decode": [
{
"file": "tests/test_vision_image.py",
"line": 17,
"test": "test_rle_decode_with_str"
},
{
"file": "tests/test_vision_image.py",
"line": 23,
"test": "test_rle_decode_empty_str"
}
],
"fastai.vision.image.rle_encode": [
{
"file": "tests/test_vision_image.py",
"line": 5,
"test": "test_rle_encode_with_array"
},
{
"file": "tests/test_vision_image.py",
"line": 11,
"test": "test_rle_encode_all_zero_array"
}
],
"fastai.vision.image.tis2hw": [
{
"file": "tests/test_vision_image.py",
"line": 29,
"test": "test_tis2hw_int"
},
{
"file": "tests/test_vision_image.py",
"line": 34,
"test": "test_tis2hw_3dims"
},
{
"file": "tests/test_vision_image.py",
"line": 39,
"test": "test_tis2hw_2dims"
},
{
"file": "tests/test_vision_image.py",
"line": 44,
"test": "test_tis2hw_str_raises_an_error"
}
],
"fastai.vision.learner._cl_int_from_learner": [
{
"file": "tests/test_vision_train.py",
"line": 72,
"test": "test_interp"
}
],
"fastai.vision.learner._learner_interpret": [
{
"file": "tests/test_vision_train.py",
"line": 78,
"test": "test_interp_shortcut"
}
],
"fastai.vision.learner.create_body": [
{
"file": "tests/test_vision_learner.py",
"line": 16,
"test": "test_create_body"
}
],
"fastai.vision.learner.create_head": [
{
"file": "tests/test_vision_learner.py",
"line": 39,
"test": "test_create_head"
}
],
"fastai.vision.learner.has_pool_type": [
{
"file": "tests/test_vision_learner.py",
"line": 45,
"test": "test_has_pool_type"
}
],
"fastai.vision.models.unet.DynamicUnet": [
{
"file": "tests/test_vision_models_unet.py",
"line": 39,
"test": "test_dynamic_unet_shape"
},
{
"file": "tests/test_vision_models_unet.py",
"line": 45,
"test": "test_unet_block_shapes"
}
],
"fastai.vision.transform._crop": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
},
{
"file": "tests/test_vision_transform.py",
"line": 123,
"test": "test_crop_without_size"
},
{
"file": "tests/test_vision_transform.py",
"line": 131,
"test": "test_crops_with_tensor_image_sizes"
}
],
"fastai.vision.transform._dihedral": [
{
"file": "tests/test_vision_transform.py",
"line": 102,
"test": "test_all_dihedral"
}
],
"fastai.vision.transform._flip_affine": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
}
],
"fastai.vision.transform._flip_lr": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
}
],
"fastai.vision.transform._pad": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
}
],
"fastai.vision.transform._perspective_warp": [
{
"file": "tests/test_vision_transform.py",
"line": 83,
"test": "test_all_warps"
}
],
"fastai.vision.transform._rotate": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
}
],
"fastai.vision.transform._skew": [
{
"file": "tests/test_vision_transform.py",
"line": 83,
"test": "test_all_warps"
}
],
"fastai.vision.transform._squish": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
}
],
"fastai.vision.transform._tilt": [
{
"file": "tests/test_vision_transform.py",
"line": 83,
"test": "test_all_warps"
}
],
"fastai.vision.transform._zoom": [
{
"file": "tests/test_vision_transform.py",
"line": 111,
"test": "test_deterministic_transforms"
}
],
"fastai.vision.transform.get_transforms": [
{
"file": "tests/test_vision_data.py",
"line": 313,
"test": "test_image_to_image_different_y_size"
},
{
"file": "tests/test_vision_data.py",
"line": 328,
"test": "test_image_to_image_different_tfms"
}
],
"fastai.widgets.image_cleaner.ImageCleaner": [
{
"file": "tests/test_widgets_image_cleaner.py",
"line": 16,
"test": "test_image_cleaner_index_length_mismatch"
},
{
"file": "tests/test_widgets_image_cleaner.py",
"line": 23,
"test": "test_image_cleaner_length_correct"
},
{
"file": "tests/test_widgets_image_cleaner.py",
"line": 30,
"test": "test_image_cleaner_wrong_input_type"
}
],
"fastai.widgets.image_downloader.ImageDownloader": [
{
"file": "tests/test_widgets_image_cleaner.py",
"line": 36,
"test": "test_image_downloader_with_path"
}
]
}
================================================
FILE: fastai/text/__init__.py
================================================
from .. import basics
from ..basics import *
from .learner import *
from .data import *
from .transform import *
from .models import *
from .. import text
# Public API of the `text` package: everything exported by `basics` and by each text submodule
# imported above (`learner`, `data`, `transform`, `models`), plus the name 'text' itself.
__all__ = [*basics.__all__, *learner.__all__, *data.__all__, *transform.__all__, *models.__all__, 'text']
================================================
FILE: fastai/text/data.py
================================================
"NLP data loading pipeline. Supports csv, folders, and preprocessed data."
from ..torch_core import *
from .transform import *
from ..basic_data import *
from ..data_block import *
from ..layers import *
from ..callback import Callback
# Public API of this module: the names exposed by a star-import of it.
__all__ = ['LanguageModelPreLoader', 'SortSampler', 'SortishSampler', 'TextList', 'pad_collate', 'TextDataBunch',
'TextLMDataBunch', 'TextClasDataBunch', 'Text', 'open_text', 'TokenizeProcessor', 'NumericalizeProcessor',
'OpenFileProcessor', 'LMLabelList', 'LMTextList', 'SPProcessor']
# Integer enumeration with the members DF, TOK and IDS (the functional `IntEnum` API numbers them from 1).
# NOTE(review): presumably tags the form in which text data is supplied (dataframe / tokens / ids) - confirm against its users.
TextMtd = IntEnum('TextMtd', 'DF TOK IDS')
# Set of file extensions treated as text files.
# NOTE(review): its consumers are outside this view - confirm how it is used before extending it.
text_extensions = {'.txt'}
class LanguageModelPreLoader(Callback):
    """Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling.

    The texts in `dataset.x.items` are read one after the other as a single token stream. The stream is cut
    into `bs` rows that start `totalToks/bs` tokens apart; every call to `__getitem__` fills one row with the
    next `bptt+1` tokens and returns it as an `(x, y)` pair where `y` is `x` shifted by one token.

    `lengths` may supply the length of every text up front; otherwise it is computed lazily in `__len__`.
    `backwards` reads every text (and the list of texts) in reverse, `shuffle` reshuffles the order of the
    texts at the beginning of each epoch after the first allocation.
    """

    class CircularIndex():
        "Index over the texts that wraps around at both ends and supports shuffling and reversed traversal."
        def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward
        def __getitem__(self, i):
            # `i` is taken modulo the number of texts; when not `forward`, positions count from the last text.
            return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)]
        def __len__(self) -> int: return len(self.idx)
        def shuffle(self): np.random.shuffle(self.idx)

    def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False,
                 shuffle:bool=False):
        self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths
        # Scale the batch size by `num_distrib()` when it returns a truthy value.
        # NOTE(review): presumably the number of distributed processes - confirm against its definition.
        self.bs *= num_distrib() or 1
        # `ite_len` and `idx` are computed lazily by `__len__` and `allocate_buffers`.
        self.totalToks,self.ite_len,self.idx = int(0),None,None

    def __len__(self):
        "Number of rows served per epoch: `bs * ceil(totalToks / (bptt*bs))`, or 1 when `item` is set."
        if self.ite_len is None:
            if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x.items])
            self.totalToks = self.lengths.sum()
            # NOTE(review): `item` is never assigned on this object in the visible code, so it presumably
            # resolves through `__getattr__` to the wrapped dataset - confirm.
            self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1
        return self.ite_len

    def __getattr__(self,k:str)->Any:
        "Delegate any attribute that is not found on this object to the wrapped `dataset`."
        # `__getattr__` only runs when normal lookup fails. If `dataset` itself is missing (an instance created
        # without `__init__`, as pickle/copy do), delegating would look up `self.dataset` again and recurse
        # until `RecursionError`; raise the regular `AttributeError` instead.
        if k == 'dataset': raise AttributeError(k)
        return getattr(self.dataset, k)

    def allocate_buffers(self):
        "Create the ragged array that will be filled when we ask for items."
        if self.ite_len is None: len(self)
        self.idx = LanguageModelPreLoader.CircularIndex(len(self.dataset.x.items), not self.backwards)
        # One extra column so that x (first `bptt` columns) and y (last `bptt` columns) are views of one buffer.
        self.batch = np.zeros((self.bs, self.bptt+1), dtype=np.int64)
        self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,1:self.bptt+1]
        #ro: index of the text we're at inside our datasets for the various batches
        self.ro = np.zeros(self.bs, dtype=np.int64)
        #ri: index of the token we're at inside our current text for the various batches
        # (`np.int` was removed in NumPy 1.24; use the same explicit dtype as `ro`)
        self.ri = np.zeros(self.bs, dtype=np.int64)

    def on_epoch_begin(self, **kwargs):
        "(Re)allocate or reshuffle the text index, then reset the start position (`ro`, `ri`) of every row."
        if self.idx is None or len(self.idx) != len(self.dataset.x.items): self.allocate_buffers()
        elif self.shuffle: self.idx.shuffle()
        self.idx.forward = not self.backwards

        # Row `i` starts `int(step*i)` tokens into the stream.
        step = self.totalToks / self.bs
        ln_rag, countTokens, i_rag = 0, 0, -1
        for i in range(0,self.bs):
            #Compute the initial values for ro and ri
            # Advance text by text until the current text contains the start position of row `i`.
            while ln_rag + countTokens <= int(step * i):
                countTokens += ln_rag
                i_rag += 1
                ln_rag = self.lengths[self.idx[i_rag]]
            self.ro[i] = i_rag
            # Offset inside the text; measured from the end of the text when reading backwards.
            self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens)

    #Training dl gets on_epoch_begin called, val_dl, on_epoch_end
    def on_epoch_end(self, **kwargs): self.on_epoch_begin()

    def __getitem__(self, k:int):
        "Fill row `k % bs` with its next `bptt+1` tokens and return it as `(x, y)` views of the buffer."
        j = k % self.bs
        if self.item is not None: return self.dataset[0]
        if self.idx is None: self.on_epoch_begin()
        self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x.items, self.idx, self.batch[j],
                                              self.ro[j], self.ri[j], overlap=1, lengths=self.lengths)
        return self.batch_x[j], self.batch_y[j]

    def fill_row(self, forward, items, idx, row, ro, ri, overlap,lengths):
        """Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented

        Starts at token `ri` of text `idx[ro]` and continues into the following texts until `row` is full.
        Returns the updated `(ro, ri)`, with `ri` moved back by `overlap` tokens so that the next call
        re-reads the last token written here.
        """
        ibuf = n = 0
        # Pre-decrement because the loop increments before its first use.
        ro -= 1
        while ibuf < row.size:
            ro += 1
            ix = idx[ro]
            rag = items[ix]
            if forward:
                # Only the first text is entered at `ri`; every following text is read from its start.
                ri = 0 if ibuf else ri
                n = min(lengths[ix] - ri, row.size - ibuf)
                row[ibuf:ibuf+n] = rag[ri:ri+n]
            else:
                # Backwards: following texts are entered at their end and copied in reversed order.
                ri = lengths[ix] if ibuf else ri
                n = min(ri, row.size - ibuf)
                row[ibuf:ibuf+n] = rag[ri-n:ri][::-1]
            ibuf += n
        return ro, ri + ((n-overlap) if forward else -(n-overlap))
class SortSampler(Sampler):
    "Yield the indices of `data_source` ordered by decreasing `key`."

    def __init__(self, data_source:NPArrayList, key:KeyFunc):
        self.data_source = data_source
        self.key = key

    def __len__(self) -> int:
        return len(self.data_source)

    def __iter__(self):
        # Largest key first.
        ordered = sorted(range_of(self.data_source), key=self.key, reverse=True)
        return iter(ordered)
class SortishSampler(Sampler):
    "Go through the text data by order of length with a bit of randomness."

    def __init__(self, data_source:NPArrayList, key:KeyFunc, bs:int):
        self.data_source,self.key,self.bs = data_source,key,bs

    def __len__(self) -> int: return len(self.data_source)

    def __iter__(self):
        "Return an iterator over all indices: batch-sized chunks of similar `key`, the chunk with the largest key first."
        n = len(self.data_source)
        # Nothing to order: `np.concatenate` and `np.argmax` both raise on empty input.
        if n == 0: return iter(np.array([], dtype=int))
        idxs = np.random.permutation(n)
        # Sort by decreasing key inside mega-chunks of 50 batches, so the randomness of the permutation is only partly lost.
        sz = self.bs*50
        ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        # Cut the result into batch-sized chunks.
        sz = self.bs
        ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]     # then make sure it goes first.
        rest = ck_idx[1:]
        if rest:
            # Shuffle the chunk *positions*: the last chunk may be shorter than the others and recent NumPy
            # versions refuse to build an array from such a ragged list.
            order = np.random.permutation(len(rest))
            sort_idx = np.concatenate([rest[i] for i in order])
        else:
            # builtin `int` replaces the `np.int` alias, which was removed in NumPy 1.24
            sort_idx = np.array([], dtype=int)
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)
def pad_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:
    "Collate `samples` into one tensor padded with `pad_idx`; the token order is flipped when `backwards`."
    samples = to_data(samples)
    longest = max(len(s[0]) for s in samples)
    # Start from a tensor full of padding and copy every sample into its row.
    res = torch.zeros(len(samples), longest).long() + pad_idx
    # The rows get flipped at the end, so padding has to go on the opposite side for now.
    if backwards:
        pad_first = not pad_first
    for i,sample in enumerate(samples):
        n = len(sample[0])
        ids = LongTensor(sample[0])
        if pad_first:
            res[i,-n:] = ids
        else:
            res[i,:n] = ids
    if backwards:
        res = res.flip(1)
    targets = tensor(np.array([s[1] for s in samples]))
    return res, targets
def _get_processor(tokenizer:Tokenizer=None, vocab:Vocab=None, chunksize:int=10000, max_vocab:int=60000,
                   min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):
    "Build the `[TokenizeProcessor, NumericalizeProcessor]` pair configured with the given options."
    tokenize = TokenizeProcessor(tokenizer=tokenizer, chunksize=chunksize, mark_fields=mark_fields,
                                 include_bos=include_bos, include_eos=include_eos)
    numericalize = NumericalizeProcessor(vocab=vocab, max_vocab=max_vocab, min_freq=min_freq)
    return [tokenize, numericalize]
class TextDataBunch(DataBunch):
    "General class to get a `DataBunch` for NLP. Subclassed by `TextLMDataBunch` and `TextClasDataBunch`."

    @classmethod
    def from_ids(cls, path:PathOrStr, vocab:Vocab, train_ids:Collection[Collection[int]], valid_ids:Collection[Collection[int]],
                 test_ids:Collection[Collection[int]]=None, train_lbls:Collection[Union[int,float]]=None,
                 valid_lbls:Collection[Union[int,float]]=None, classes:Collection[Any]=None,
                 processor:PreProcessor=None, **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from ids, labels and a `vocab`. `kwargs` are passed to the dataloader creation."
        # The ids are already numericalized, so the lists are built with an empty processor list.
        src = ItemLists(path, TextList(train_ids, vocab, path=path, processor=[]),
                        TextList(valid_ids, vocab, path=path, processor=[]))
        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_lists(train_lbls, valid_lbls, classes=classes, processor=[])
        if not is1d(train_lbls): src.train.y.one_hot,src.valid.y.one_hot = True,True
        if test_ids is not None: src.add_test(TextList(test_ids, vocab, path=path), label=train_lbls[0])
        # Processors are attached after labelling, on the validation lists only.
        src.valid.x.processor = ifnone(processor, [TokenizeProcessor(), NumericalizeProcessor(vocab=vocab)])
        if classes is not None: src.valid.y.processor = ifnone(processor, [CategoryProcessor(src.valid.y)])
        return src.databunch(**kwargs)

    @classmethod
    def load(cls, path:PathOrStr, cache_name:PathOrStr='tmp', processor:PreProcessor=None, **kwargs):
        "Load a `TextDataBunch` from `path/cache_name`. `kwargs` are passed to the dataloader creation."
        warn("""This method is deprecated and only kept to load data serialized in v1.0.43 or earlier.
Use `load_data` for data saved with v1.0.44 or later.""", DeprecationWarning)
        cache_path = Path(path)/cache_name
        # NOTE: unpickling runs arbitrary code; only load caches that come from a trusted source.
        # The context manager closes the file handle, which the previous bare `open` never did.
        with open(cache_path/'itos.pkl','rb') as f: vocab = Vocab(pickle.load(f))
        train_ids,train_lbls = np.load(cache_path/f'train_ids.npy'), np.load(cache_path/f'train_lbl.npy')
        valid_ids,valid_lbls = np.load(cache_path/f'valid_ids.npy'), np.load(cache_path/f'valid_lbl.npy')
        # The test ids and the class names are optional parts of the cache.
        test_ids = np.load(cache_path/f'test_ids.npy') if os.path.isfile(cache_path/f'test_ids.npy') else None
        classes = loadtxt_str(cache_path/'classes.txt') if os.path.isfile(cache_path/'classes.txt') else None
        return cls.from_ids(path, vocab, train_ids, valid_ids, test_ids, train_lbls, valid_lbls, classes, processor, **kwargs)

    @classmethod#TODO: test
    def from_tokens(cls, path:PathOrStr, trn_tok:Collection[Collection[str]], trn_lbls:Collection[Union[int,float]],
                    val_tok:Collection[Collection[str]], val_lbls:Collection[Union[int,float]], vocab:Vocab=None,
                    tst_tok:Collection[Collection[str]]=None, classes:Collection[Any]=None, max_vocab:int=60000, min_freq:int=3,
                    **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from tokens and labels. `kwargs` are passed to the dataloader creation."
        # The texts are already tokenized: only numericalization is left to do.
        processor = NumericalizeProcessor(vocab=vocab, max_vocab=max_vocab, min_freq=min_freq)
        src = ItemLists(path, TextList(trn_tok, path=path, processor=processor),
                        TextList(val_tok, path=path, processor=processor))
        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_lists(trn_lbls, val_lbls, classes=classes)
        if tst_tok is not None: src.add_test(TextList(tst_tok, path=path))
        return src.databunch(**kwargs)

    @classmethod
    def from_df(cls, path:PathOrStr, train_df:DataFrame, valid_df:DataFrame, test_df:Optional[DataFrame]=None,
                tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, text_cols:IntsOrStrs=1,
                label_cols:IntsOrStrs=0, label_delim:str=None, chunksize:int=10000, max_vocab:int=60000,
                min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from DataFrames. `kwargs` are passed to the dataloader creation."
        processor = _get_processor(tokenizer=tokenizer, vocab=vocab, chunksize=chunksize, max_vocab=max_vocab,
                                   min_freq=min_freq, mark_fields=mark_fields,
                                   include_bos=include_bos, include_eos=include_eos)
        # Several label columns and no explicit classes: the column names become the classes.
        if classes is None and is_listy(label_cols) and len(label_cols) > 1: classes = label_cols
        src = ItemLists(path, TextList.from_df(train_df, path, cols=text_cols, processor=processor),
                        TextList.from_df(valid_df, path, cols=text_cols, processor=processor))
        if cls==TextLMDataBunch: src = src.label_for_lm()
        else:
            # `label_delim` is only forwarded when it was given.
            if label_delim is not None: src = src.label_from_df(cols=label_cols, classes=classes, label_delim=label_delim)
            else: src = src.label_from_df(cols=label_cols, classes=classes)
        if test_df is not None: src.add_test(TextList.from_df(test_df, path, cols=text_cols))
        return src.databunch(**kwargs)

    @classmethod
    def from_csv(cls, path:PathOrStr, csv_name, valid_pct:float=0.2, test:Optional[str]=None,
                 tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, delimiter:str=None, header='infer',
                 text_cols:IntsOrStrs=1, label_cols:IntsOrStrs=0, label_delim:str=None,
                 chunksize:int=10000, max_vocab:int=60000, min_freq:int=2,
                 mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from texts in csv files. `kwargs` are passed to the dataloader creation."
        df = pd.read_csv(Path(path)/csv_name, header=header, delimiter=delimiter)
        # Shuffle the rows, then use the first `cut` rows for validation and the rest for training.
        df = df.iloc[np.random.permutation(len(df))]
        cut = int(valid_pct * len(df)) + 1
        train_df, valid_df = df[cut:], df[:cut]
        test_df = None if test is None else pd.read_csv(Path(path)/test, header=header, delimiter=delimiter)
        return cls.from_df(path, train_df, valid_df, test_df, tokenizer=tokenizer, vocab=vocab, classes=classes, text_cols=text_cols,
                           label_cols=label_cols, label_delim=label_delim, chunksize=chunksize, max_vocab=max_vocab,
                           min_freq=min_freq, mark_fields=mark_fields,
                           include_bos=include_bos, include_eos=include_eos, **kwargs)

    @classmethod
    def from_folder(cls, path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None,
                    classes:Collection[Any]=None, tokenizer:Tokenizer=None, vocab:Vocab=None, chunksize:int=10000, max_vocab:int=60000,
                    min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs):
        "Create a `TextDataBunch` from text files in folders."
        path = Path(path).absolute()
        # The items are file names, so the files have to be opened before tokenization.
        processor = [OpenFileProcessor()] + _get_processor(tokenizer=tokenizer, vocab=vocab, chunksize=chunksize, max_vocab=max_vocab,
                                                           min_freq=min_freq, mark_fields=mark_fields, include_bos=include_bos, include_eos=include_eos)
        src = (TextList.from_folder(path, processor=processor)
                       .split_by_folder(train=train, valid=valid))
        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_folder(classes=classes)
        if test is not None: src.add_test_folder(path/test)
        return src.databunch(**kwargs)
class TextLMDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training a language model."

    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None,
               num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate,
               dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70, backwards:bool=False, **dl_kwargs) -> DataBunch:
        "Create a `TextDataBunch` in `path` from the `datasets` for language modelling. Passes `**dl_kwargs` on to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        # Only the first (training) dataset is shuffled; every other one uses the validation batch size.
        datasets = [LanguageModelPreLoader(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, backwards=backwards)
                    for i,ds in enumerate(datasets)]
        # Each `DataLoader` must batch with the same size its preloader was built with (the preloader maps item `k`
        # to row `k % bs`), so `val_bs` is no longer overwritten with `bs` before this point.
        dls = [DataLoader(d, b, shuffle=False, **dl_kwargs) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None]
        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
class TextClasDataBunch(TextDataBunch):
    "`TextDataBunch` whose dataloaders are tailored to training an RNN classifier."

    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,
               pad_first=True, device:torch.device=None, no_check:bool=False, backwards:bool=False,
               dl_tfms:Optional[Collection[Callable]]=None, **dl_kwargs) -> DataBunch:
        "Turn the `datasets` into a `DataBunch` for classification. `**dl_kwargs` are passed on to `DataLoader()`."
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(pad_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)

        # Training set: roughly sorted by length, with some randomness, and incomplete batches dropped.
        train_set = datasets[0]
        def _train_len(t): return len(train_set[t][0].data)
        train_sampler = SortishSampler(train_set.x, key=_train_len, bs=bs)
        dataloaders = [DataLoader(train_set, batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)]

        # Remaining sets: strictly sorted by decreasing length.
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        return cls(*dataloaders, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
def open_text(fn:PathOrStr, enc='utf-8'):
    "Return the whole content of the text file `fn`, decoded with `enc`."
    with open(fn, 'r', encoding=enc) as f:
        content = f.read()
    return content
class Text(ItemBase):
    "Text item: `data` holds the numericalized `ids` as an int64 array, `text` the matching string."

    def __init__(self, ids, text):
        self.data = np.array(ids, dtype=np.int64)
        self.text = text

    def __str__(self):
        return str(self.text)
class TokenizeProcessor(PreProcessor):
    "`PreProcessor` that joins the text fields of `ds` and tokenizes them."

    def __init__(self, ds:ItemList=None, tokenizer:Tokenizer=None, chunksize:int=10000,
                 mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):
        self.tokenizer = ifnone(tokenizer, Tokenizer())
        self.chunksize = chunksize
        self.mark_fields = mark_fields
        self.include_bos = include_bos
        self.include_eos = include_eos

    def process_one(self, item):
        "Tokenize a single `item`."
        joined = _join_texts([item], self.mark_fields, self.include_bos, self.include_eos)
        return self.tokenizer._process_all_1(joined)[0]

    def process(self, ds):
        "Replace `ds.items` by their tokens, working through chunks of `chunksize` texts."
        ds.items = _join_texts(ds.items, self.mark_fields, self.include_bos, self.include_eos)
        tokens = []
        for start in progress_bar(range(0,len(ds),self.chunksize), leave=False):
            chunk = ds.items[start:start+self.chunksize]
            tokens.extend(self.tokenizer.process_all(chunk))
        ds.items = tokens
class NumericalizeProcessor(PreProcessor):
    "`PreProcessor` that turns the tokens in `ds` into ids, creating the vocabulary when none is given."

    def __init__(self, ds:ItemList=None, vocab:Vocab=None, max_vocab:int=60000, min_freq:int=3):
        # Fall back on the vocabulary of `ds` when no `vocab` is passed.
        ds_vocab = ds.vocab if ds is not None else None
        self.vocab = ifnone(vocab, ds_vocab)
        self.max_vocab = max_vocab
        self.min_freq = min_freq

    def process_one(self,item):
        "Numericalize one tokenized `item` into an int64 array."
        ids = self.vocab.numericalize(item)
        return np.array(ids, dtype=np.int64)

    def process(self, ds):
        "Create the vocabulary from `ds.items` if needed, store it on `ds`, then numericalize the items."
        if self.vocab is None:
            self.vocab = Vocab.create(ds.items, self.max_vocab, self.min_freq)
        ds.vocab = self.vocab
        super().process(ds)
class OpenFileProcessor(PreProcessor):
    "`PreProcessor` that opens the filenames and read the texts."

    def process(self, ds:Collection):
        "Replace every `Path` in `ds.items` by the text of the file it points to."
        # builtin `object` replaces the `np.object` alias (deprecated in NumPy 1.20, removed in 1.24)
        ds.items = array([self.process_one(item) for item in ds.items], dtype=object)

    def process_one(self,item):
        "Read `item` if it is a `Path`; anything else is returned unchanged."
        return open_text(item) if isinstance(item, Path) else item
class TextList(ItemList):
    "`ItemList` for text data, tokenized and numericalized by its default processors."
    _bunch = TextClasDataBunch
    _processor = [TokenizeProcessor, NumericalizeProcessor]
    _is_lm = False

    def __init__(self, items:Iterator, vocab:Vocab=None, pad_idx:int=1, sep=' ', **kwargs):
        super().__init__(items, **kwargs)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.sep = sep
        # These attributes are added to `copy_new` along with the ones already listed there.
        self.copy_new += ['vocab', 'pad_idx', 'sep']

    def get(self, i):
        "Return item `i`, wrapped in a `Text` when a vocabulary is available."
        o = super().get(i)
        if self.vocab is None:
            return o
        return Text(o, self.vocab.textify(o, self.sep))

    def label_for_lm(self, **kwargs):
        "Labelling method for language models: switch to `LMTextList` and use constant dummy labels."
        self.__class__ = LMTextList
        kwargs.update(label_cls=LMLabelList)
        return self.label_const(0, **kwargs)

    def reconstruct(self, t:Tensor):
        "Rebuild a `Text` from `t`, trimming the padding before the first and after the last non-pad token."
        non_pad = (t != self.pad_idx).nonzero()
        first, last = non_pad.min(), non_pad.max()
        trimmed = t[first:last+1]
        return Text(trimmed, self.vocab.textify(trimmed))

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=text_extensions, vocab:Vocab=None,
                    processor:PreProcessor=None, **kwargs)->'TextList':
        "Get the list of files in `path` that have a text suffix; `kwargs` go to the parent `from_folder`."
        if processor is None:
            processor = [OpenFileProcessor(), TokenizeProcessor(), NumericalizeProcessor(vocab=vocab)]
        return super().from_folder(path=path, extensions=extensions, processor=processor, **kwargs)

    def show_xys(self, xs, ys, max_len:int=70)->None:
        "Show the `xs` (inputs) and `ys` (targets). `max_len` is the maximum number of tokens displayed."
        from IPython.display import display, HTML
        names = ['idx','text'] if self._is_lm else ['text','target']
        rows = []
        for i, (x,y) in enumerate(zip(xs,ys)):
            if max_len is not None:
                txt_x = ' '.join(x.text.split(' ')[:max_len])
            else:
                txt_x = x.text
            # A language model has no real target, so the row number is shown instead.
            rows.append([i, txt_x] if self._is_lm else [txt_x, y])
        rows = np.array(rows)
        df = pd.DataFrame({name:rows[:,col] for col,name in enumerate(names)}, columns=names)
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))

    def show_xyzs(self, xs, ys, zs, max_len:int=70):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions). `max_len` is the maximum number of tokens displayed."
        from IPython.display import display, HTML
        names = ['text','target','prediction']
        rows = []
        for x,y,z in zip(xs,ys,zs):
            if max_len is not None:
                txt_x = ' '.join(x.text.split(' ')[:max_len])
            else:
                txt_x = x.text
            rows.append([txt_x, y, z])
        rows = np.array(rows)
        df = pd.DataFrame({name:rows[:,col] for col,name in enumerate(names)}, columns=names)
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
class LMLabelList(EmptyLabelList):
    "Dummy label list for language models; it only carries the loss function."

    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        # Loss function attached to the labels.
        self.loss_func = CrossEntropyFlat()
class LMTextList(TextList):
    "`TextList` flavour used for language modelling: it builds a `TextLMDataBunch`."
    # Class of the data bunch associated with this list.
    _bunch = TextLMDataBunch
    # Flag read by `show_xys` to pick the language-model layout.
    _is_lm = True
def _join_texts(texts:Collection[str], mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):
    "Merge the columns of `texts` into one string per row, adding the BOS/FLD/EOS markers that are requested."
    if not isinstance(texts, np.ndarray):
        texts = np.array(texts)
    # A flat list of texts is handled as a single column.
    if is1d(texts):
        texts = texts[:,None]
    n_cols = texts.shape[1]
    df = pd.DataFrame({col:texts[:,col] for col in range(n_cols)})
    bos_tok = f'{BOS} ' if include_bos else ''
    first = df[0].astype(str)
    if mark_fields:
        text_col = f'{bos_tok}{FLD} {1} ' + first
    else:
        text_col = f'{bos_tok}' + first
    # Append the remaining columns, each announced by its 1-based field number when `mark_fields`.
    for col in range(1,len(df.columns)):
        sep = f' {FLD} {col+1} ' if mark_fields else ' '
        text_col += sep + df[col].astype(str)
    if include_eos:
        text_col = text_col + f' {EOS}'
    return text_col.values
def apply_rules(text, pre_rules=None, post_rules=None):
    "Run `pre_rules` on `text`, split it on whitespace, run `post_rules` on the tokens and join them with spaces."
    text = text.strip(' ')
    # String-level rules; the defaults come from `defaults.text_pre_rules`.
    for rule in ifnone(pre_rules, defaults.text_pre_rules):
        text = rule(text)
    toks = text.split()
    # Token-level rules; the defaults come from `defaults.text_post_rules`.
    for rule in ifnone(post_rules, defaults.text_post_rules):
        toks = rule(toks)
    return ' '.join(toks)
def get_default_size(texts, max_vocab_sz):
    "Return `max_vocab_sz`, or a quarter of the number of unique words in `texts` rounded up to a multiple of 8."
    unique_words = {word for t in texts for word in t.split()}
    quarter = len(unique_words)//4
    if quarter > max_vocab_sz:
        return max_vocab_sz
    # Distance to the next multiple of 8 (0 when `quarter` already is one).
    return quarter + (-quarter) % 8
# Languages (all European) that get the higher default `character_coverage` in `train_sentencepiece`.
full_char_coverage_langs = [
    "bg", "cs", "da", "de", "el", "en", "es", "et",
    "fi", "fr", "ga", "hr", "hu", "it", "lt", "lv",
    "mt", "nl", "pl", "pt", "ro", "sk", "sl", "sv",
]
def train_sentencepiece(texts:Collection[str], path:PathOrStr, pre_rules: ListRules=None, post_rules:ListRules=None,
                        vocab_sz:int=None, max_vocab_sz:int=30000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',
                        char_coverage=None, tmp_dir='tmp', enc:str='utf8'):
    "Train a sentencepiece tokenizer on `texts` and save it in `path/tmp_dir`"
    # `pre_rules` and `post_rules` are accepted for interface compatibility but are not used in this function.
    from sentencepiece import SentencePieceTrainer
    cache_dir = Path(path)/tmp_dir
    os.makedirs(cache_dir, exist_ok=True)
    if vocab_sz is None: vocab_sz=get_default_size(texts, max_vocab_sz)
    raw_text_path = cache_dir / 'all_text.out'
    # Write the corpus with an explicit encoding (`enc`, UTF-8 by default) instead of the platform default, so
    # non-ASCII texts are written the same way on every OS.
    with open(raw_text_path, 'w', encoding=enc) as f: f.write("\n".join(texts))
    # The special tokens are registered as user-defined symbols, prefixed with sentencepiece's word-boundary mark.
    spec_tokens = ['\u2581'+s for s in defaults.text_spec_tok]
    SentencePieceTrainer.Train(" ".join([
        f"--input={raw_text_path} --max_sentence_length={max_sentence_len}",
        f"--character_coverage={ifnone(char_coverage, 0.99999 if lang in full_char_coverage_langs else 0.9998)}",
        f"--unk_id={len(defaults.text_spec_tok)} --pad_id=-1 --bos_id=-1 --eos_id=-1",
        f"--user_defined_symbols={','.join(spec_tokens)}",
        f"--model_prefix={cache_dir/'spm'} --vocab_size={vocab_sz} --model_type={model_type}"]))
    # The temporary corpus file is not needed once training is done.
    raw_text_path.unlink()
    return cache_dir
class SPProcessor(PreProcessor):
    "`PreProcessor` that tokenizes and numericalizes with `sentencepiece`"

    def __init__(self, ds:ItemList=None, pre_rules: ListRules=None, post_rules:ListRules=None, vocab_sz:int=None,
                 max_vocab_sz:int=30000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',
                 char_coverage=None, tmp_dir='tmp', mark_fields:bool=False, include_bos:bool=True,
                 include_eos:bool=False, sp_model=None, sp_vocab=None, n_cpus:int=None):
        # Fail early, with an installation hint, when sentencepiece is not available.
        try: from sentencepiece import SentencePieceTrainer,SentencePieceProcessor
        except ImportError:
            raise Exception('sentencepiece module is missing: run `pip install sentencepiece`')
        self.pre_rules,self.post_rules = pre_rules,post_rules
        self.mark_fields,self.include_bos,self.include_eos = mark_fields,include_bos,include_eos
        self.sp_model,self.sp_vocab,self.n_cpus = sp_model,sp_vocab,ifnone(n_cpus,defaults.cpus)
        # Training is deferred until `process` finds out that no model/vocab pair was provided.
        self.train_func = partial(train_sentencepiece, pre_rules=pre_rules, post_rules=post_rules, vocab_sz=vocab_sz,
                                  max_vocab_sz=max_vocab_sz, model_type=model_type, max_sentence_len=max_sentence_len, lang=lang,
                                  char_coverage=char_coverage, tmp_dir=tmp_dir)

    def process_one(self, item, join=True):
        "Encode a single `item`; with `join=False` the item is taken as an already joined text."
        # `text` used to be left undefined when `join` was False, which raised `UnboundLocalError`.
        text = _join_texts([item], self.mark_fields, self.include_bos, self.include_eos)[0] if join else item
        text = apply_rules(text, pre_rules=self.pre_rules, post_rules=self.post_rules)
        return self._encode_batch([text])[0]

    def process(self, ds):
        "Join and clean the texts of `ds`, train sentencepiece if needed, then replace the items by their ids."
        ds.items = _join_texts(ds.items, self.mark_fields, self.include_bos, self.include_eos)
        ds.items = [apply_rules(t, pre_rules=self.pre_rules, post_rules=self.post_rules)
                    for t in progress_bar(ds.items, leave=False)]
        if self.sp_model is None or self.sp_vocab is None:
            cache_dir = self.train_func(ds.items, ds.path)
            self.sp_model,self.sp_vocab = cache_dir/'spm.model',cache_dir/'spm.vocab'
        if not getattr(self, 'vocab', False):
            # Each vocab line is "<piece>\t<score>". The file is read as UTF-8 explicitly (sentencepiece works
            # with UTF-8 text) rather than with the platform default encoding.
            with open(self.sp_vocab, 'r', encoding='utf8') as f: self.vocab = Vocab([line.split('\t')[0] for line in f.readlines()])
        if self.n_cpus <= 1: ds.items = self._encode_batch(ds.items)
        else:
            # Encode the partitions in separate processes and concatenate the per-partition lists.
            with ProcessPoolExecutor(self.n_cpus) as e:
                ds.items = np.array(sum(e.map(self._encode_batch, partition_by_cores(ds.items, self.n_cpus)), []))
        ds.vocab = self.vocab

    def _encode_batch(self, texts):
        "Load the sentencepiece model and encode every text of `texts` into an array of ids."
        from sentencepiece import SentencePieceProcessor
        tok = SentencePieceProcessor()
        tok.Load(str(self.sp_model))
        return [np.array(tok.EncodeAsIds(t)) for t in texts]

    @classmethod
    def load(cls, path:PathOrStr, tmp_dir:PathOrStr='tmp', name:str='spm'):
        "Create a processor from the `name`.model and `name`.vocab files found in `path/tmp_dir`."
        cache_dir = Path(path)/tmp_dir
        return cls(sp_model=cache_dir/f'{name}.model', sp_vocab=cache_dir/f'{name}.vocab')
================================================
FILE: fastai/text/interpret.py
================================================
from ..torch_core import *
from ..basic_data import *
from ..basic_train import *
from ..train import ClassificationInterpretation
import matplotlib.cm as cm
__all__ = ['TextClassificationInterpretation']
def value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:
    "Map `x` (from 0 to 1 inclusive) to an RGBA tuple through `cmap`, scaling the transparency by `alpha_mult`."
    *channels, alpha = cmap(x)
    # Colour channels become integers on the 0-255 scale; alpha is kept as returned by `cmap`, times `alpha_mult`.
    rgb = (np.array(channels) * 255).astype(int)
    return tuple(rgb.tolist() + [alpha * alpha_mult])
def piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:
    "Return HTML showing each of `pieces` on a background colour given by its value in `attns`; `kwargs` go to `value2rgba`."
    # The colour used to be computed and then thrown away, and the wrapper strings were empty, so the output was
    # only the escaped text. Every piece is now wrapped in a span coloured by its attention value.
    html_code,spans = ['<span style="font-family: monospace;">'], []
    for p, a in zip(pieces, attns):
        p = html.escape(p)  # the pieces are arbitrary text: escape them before embedding them in markup
        c = str(value2rgba(a, alpha_mult=0.5, **kwargs))
        spans.append(f'<span title="{a:.3f}" style="background-color: rgba{c};">{p}</span>')
    html_code.append(sep.join(spans))
    html_code.append('</span>')
    return ''.join(html_code)
def show_piece_attn(*args, **kwargs):
    "Display in the notebook the HTML built by `piece_attn_html(*args, **kwargs)`."
    from IPython.display import display, HTML
    markup = piece_attn_html(*args, **kwargs)
    display(HTML(markup))
def _eval_dropouts(mod):
    "Recursively clear `training` on `mod` and its descendants whose class name contains 'Dropout' or 'BatchNorm'."
    module_name = mod.__class__.__name__
    if any(tag in module_name for tag in ('Dropout', 'BatchNorm')):
        mod.training = False
    for child in mod.children():
        _eval_dropouts(child)
class TextClassificationInterpretation(ClassificationInterpretation):
    """Interpretation of a text classifier based on the sensitivity of the output to the input.
    Designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.
    """

    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):
        super().__init__(learn,preds,y_true,losses,ds_type)
        self.model = learn.model

    @classmethod
    def from_learner(cls, learn: Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):
        "Build the interpretation from the predictions, targets and losses of `learn` on `ds_type`."
        preds_res = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True, ordered=True)
        return cls(learn, *preds_res)

    def intrinsic_attention(self, text:str, class_id:int=None):
        """Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.
        For reference, see the Sequential Jacobian session at https://www.cs.toronto.edu/~graves/preprint.pdf
        """
        # Training mode, except for the dropout/batchnorm layers that `_eval_dropouts` switches off.
        self.model.train()
        _eval_dropouts(self.model)
        self.model.zero_grad()
        self.model.reset()
        ids = self.data.one_item(text)[0]
        # Feed the embeddings as a detached leaf tensor so that gradients are collected on them.
        encoder = self.model[0].module
        emb = encoder.encoder(ids).detach().requires_grad_(True)
        lstm_output = encoder(emb, from_embeddings=True)
        self.model.eval()
        mask = torch.zeros_like(ids).byte()
        cl = self.model[1](lstm_output + (mask,))[0].softmax(dim=-1)
        if class_id is None:
            class_id = cl.argmax()
        cl[0][class_id].backward()
        # One score per token: summed absolute gradient over the last dimension, normalised by the maximum.
        attn = emb.grad.squeeze().abs().sum(dim=-1)
        attn /= attn.max()
        tokens = self.data.single_ds.reconstruct(ids[0])
        return tokens, attn

    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:
        "Return the HTML rendering of the intrinsic attention of `text`; `kwargs` go to `piece_attn_html`."
        tokens, attn = self.intrinsic_attention(text, class_id)
        return piece_attn_html(tokens.text.split(), to_np(attn), **kwargs)

    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:
        "Display the intrinsic attention of `text`; `kwargs` go to `show_piece_attn`."
        tokens, attn = self.intrinsic_attention(text, class_id)
        show_piece_attn(tokens.text.split(), to_np(attn), **kwargs)

    def show_top_losses(self, k:int, max_len:int=70)->None:
        """
        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability of
        actual class. `max_len` is the maximum number of tokens displayed.
        """
        from IPython.display import display, HTML
        rows = []
        _,tl_idx = self.top_losses()
        remaining = k
        for idx in tl_idx:
            if remaining <= 0: break
            remaining -= 1
            tx,cl = self.data.dl(self.ds_type).dataset[idx]
            cl = cl.data
            classes = self.data.classes
            if max_len is not None:
                txt = ' '.join(tx.text.split(' ')[:max_len])
            else:
                txt = tx.text
            rows.append([txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',
                         f'{self.preds[idx][cl]:.2f}'])
        rows = np.array(rows)
        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']
        df = pd.DataFrame({name:rows[:,col] for col,name in enumerate(names)}, columns=names)
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
================================================
FILE: fastai/text/learner.py
================================================
'Model training for NLP'
from ..torch_core import *
from ..basic_train import *
from ..callbacks import *
from ..data_block import CategoryList
from ..basic_data import *
from ..datasets import *
from ..metrics import accuracy
from ..train import GradientClipping
from ..layers import *
from .models import *
from .transform import *
from .data import *
__all__ = ['RNNLearner', 'LanguageLearner', 'convert_weights', 'decode_spec_tokens', 'get_language_model', 'language_model_learner',
'MultiBatchEncoder', 'get_text_classifier', 'text_classifier_learner', 'PoolingLinearClassifier']
# Per-architecture settings: name of the config entry holding the hidden size, pretrained-weight urls where
# available, and the default config and layer-group split function for language models and for classifiers.
_model_meta = {
    AWD_LSTM: {
        'hid_name':'emb_sz', 'url':URLs.WT103_FWD, 'url_bwd':URLs.WT103_BWD,
        'config_lm':awd_lstm_lm_config, 'split_lm': awd_lstm_lm_split,
        'config_clas':awd_lstm_clas_config, 'split_clas': awd_lstm_clas_split,
    },
    Transformer: {
        'hid_name':'d_model', 'url':URLs.OPENAI_TRANSFORMER,
        'config_lm':tfmer_lm_config, 'split_lm': tfmer_lm_split,
        'config_clas':tfmer_clas_config, 'split_clas': tfmer_clas_split,
    },
    TransformerXL: {
        'hid_name':'d_model',
        'config_lm':tfmerXL_lm_config, 'split_lm': tfmerXL_lm_split,
        'config_clas':tfmerXL_clas_config, 'split_clas': tfmerXL_clas_split,
    },
}
def convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:
    "Remap the embedding/decoder rows of `wgts` to the vocabulary `itos_new`; unknown tokens get the mean row."
    dec_bias = wgts.get('1.decoder.bias', None)
    enc_wgts = wgts['0.encoder.weight']
    has_bias = dec_bias is not None
    n_new = len(itos_new)

    # Mean row (and mean bias), used for the tokens that are missing from the old vocabulary.
    wgts_m = enc_wgts.mean(0)
    new_w = enc_wgts.new_zeros((n_new,enc_wgts.size(1))).zero_()
    if has_bias:
        bias_m = dec_bias.mean(0)
        new_b = dec_bias.new_zeros((n_new,)).zero_()

    for i,tok in enumerate(itos_new):
        old = stoi_wgts.get(tok, -1)
        if old >= 0:
            new_w[i] = enc_wgts[old]
            if has_bias: new_b[i] = dec_bias[old]
        else:
            new_w[i] = wgts_m
            if has_bias: new_b[i] = bias_m

    # The same matrix is stored under every key that holds the embedding weights, and as decoder weights.
    wgts['0.encoder.weight'] = new_w
    if '0.encoder_dp.emb.weight' in wgts:
        wgts['0.encoder_dp.emb.weight'] = new_w.clone()
    wgts['1.decoder.weight'] = new_w.clone()
    if has_bias:
        wgts['1.decoder.bias'] = new_b
    return wgts
class RNNLearner(Learner):
    "Basic class for a `Learner` in NLP."

    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        # Accuracy is the default metric when the training labels are categories or language-model labels.
        is_class = (hasattr(data.train_ds, 'y') and (isinstance(data.train_ds.y, CategoryList) or
                                                     isinstance(data.train_ds.y, LMLabelList)))
        metrics = ifnone(metrics, ([accuracy] if is_class else []))
        super().__init__(data, model, metrics=metrics, **learn_kwargs)
        self.callbacks.append(RNNTrainer(self, alpha=alpha, beta=beta))
        if clip: self.callback_fns.append(partial(GradientClipping, clip=clip))
        if split_func: self.split(split_func)

    def save_encoder(self, name:str):
        "Save the encoder to `name` inside the model directory."
        if is_pathlike(name): self._test_writeable_path()
        encoder = get_model(self.model)[0]
        # Unwrap the encoder when it is held inside a wrapper exposing `.module`.
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), self.path/self.model_dir/f'{name}.pth')

    def load_encoder(self, name:str, device:torch.device=None):
        "Load the encoder `name` from the model directory."
        encoder = get_model(self.model)[0]
        if device is None: device = self.data.device
        if hasattr(encoder, 'module'): encoder = encoder.module
        encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device))
        self.freeze()

    def load_pretrained(self, wgts_fname:str, itos_fname:str, strict:bool=True):
        "Load a pretrained model and adapts it to the data vocabulary."
        # NOTE: unpickling runs arbitrary code; only load vocabulary files that come from a trusted source.
        # The context manager closes the file handle, which the previous bare `open` never did.
        with open(itos_fname, 'rb') as f: old_itos = pickle.load(f)
        old_stoi = {v:k for k,v in enumerate(old_itos)}
        wgts = torch.load(wgts_fname, map_location=lambda storage, loc: storage)
        # Some checkpoints store the weights under a 'model' key.
        if 'model' in wgts: wgts = wgts['model']
        wgts = convert_weights(wgts, old_stoi, self.data.train_ds.vocab.itos)
        self.model.load_state_dict(wgts, strict=strict)

    def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False, n_batch:Optional[int]=None,
                  pbar:Optional[PBar]=None, ordered:bool=False) -> List[Tensor]:
        "Return predictions and targets on the valid, train, or test set, depending on `ds_type`."
        self.model.reset()
        # Seeding makes the sampler order reproducible, so that it can be replayed and undone below.
        if ordered: np.random.seed(42)
        preds = super().get_preds(ds_type=ds_type, activ=activ, with_loss=with_loss, n_batch=n_batch, pbar=pbar)
        if ordered and hasattr(self.dl(ds_type), 'sampler'):
            np.random.seed(42)
            sampler = [i for i in self.dl(ds_type).sampler]
            # `argsort` of the sampled order gives the permutation that puts the results back in dataset order.
            reverse_sampler = np.argsort(sampler)
            preds = [p[reverse_sampler] for p in preds]
        return preds
def decode_spec_tokens(tokens):
    "Apply the special rule tokens (`TK_MAJ`, `TK_UP`, `TK_REP`, `TK_WREP`) found in `tokens` and return the decoded list."
    new_toks,rule,arg = [],None,None
    for t in tokens:
        if t in [TK_MAJ, TK_UP, TK_REP, TK_WREP]:
            # A new rule starts from scratch: forget any count left over from an unfinished repetition.
            rule,arg = t,None
        elif rule is None: new_toks.append(t)
        elif rule == TK_MAJ:
            new_toks.append(t[:1].upper() + t[1:].lower())
            rule = None
        elif rule == TK_UP:
            new_toks.append(t.upper())
            rule = None
        elif arg is None:
            # Repetition rules expect a count first; a token that is not an integer cancels the rule
            # (and, as before, that token is not emitted).
            try: arg = int(t)
            except ValueError: rule = None
        else:
            if rule == TK_REP: new_toks.append(t * arg)  # repeat inside a single token
            else: new_toks += [t] * arg                   # repeat as separate tokens
            # The repetition is complete: without this reset every following token was repeated as well.
            rule,arg = None,None
    return new_toks
class LanguageLearner(RNNLearner):
"Subclass of RNNLearner for predictions."
def predict(self, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
decoder=decode_spec_tokens):
"Return the `n_words` that come after `text`."
ds = self.data.single_dl.dataset
self.model.reset()
xb,yb = self.data.one_item(text)
new_idx = []
for _ in range(n_words): #progress_bar(range(n_words), leave=False):
res = self.pred_batch(batch=(xb,yb))[0][-1]
#if len(new_idx) == 0: self.model[0].select_hidden([0])
if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
if min_p is not None:
if (res >= min_p).float().sum() == 0:
warn(f"There is no item with probability >= {min_p}, try a lower value.")
else: res[res < min_p] = 0.
if temperature != 1.: res.pow_(1 / temperature)
idx = torch.multinomial(res, 1).item()
new_idx.append(idx)
xb = xb.new_tensor([idx])[None]
return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))
def beam_search(self, text:str, n_words:int, no_unk:bool=True, top_k:int=10, beam_sz:int=1000, temperature:float=1.,
sep:str=' ', decoder=decode_spec_tokens):
"Return the `n_words` that come after `text` using beam search."
ds = self.data.single_dl.dataset
self.model.reset()
self.model.eval()
xb, yb = self.data.one_item(text)
nodes = None
nodes = xb.clone()
scores = xb.new_zeros(1).float()
with torch.no_grad():
for k in progress_bar(range(n_words), leave=False):
out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
if no_unk: out[:,self.data.vocab.stoi[UNK]] = -float('Inf')
values, indices = out.topk(top_k, dim=-1)
scores = (-values + scores[:,None]).view(-1)
indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
sort_idx = scores.argsort()[:beam_sz]
scores = scores[sort_idx]
nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
nodes = nodes.view(-1, nodes.size(2))[sort_idx]
self.model[0].select_hidden(indices_idx[sort_idx])
xb = nodes[:,-1][:,None]
if temperature != 1.: scores.div_(temperature)
node_idx = torch.multinomial(torch.exp(-scores), 1).item()
return text + sep + sep.join(decoder(self.data.vocab.textify([i.item() for i in nodes[node_idx][1:] ], sep=None)))
def show_results(self, ds_type=DatasetType.Valid, rows:int=5, max_len:int=20):
    """Show `rows` result of predictions on `ds_type` dataset.

    Takes one batch of `ds_type`, runs the model on it and displays an HTML table with
    the columns `text`, `target` and `pred`. `text` holds the first `max_len` tokens of
    the input; `target` and `pred` hold tokens `max_len-1` to `2*max_len-2`.
    """
    # Imported lazily so the module does not require IPython at import time.
    from IPython.display import display, HTML
    ds = self.dl(ds_type).dataset
    x,y = self.data.one_batch(ds_type, detach=False, denorm=False)
    preds = self.pred_batch(batch=(x,y))
    y = y.view(*x.size())
    z = preds.view(*x.size(),-1).argmax(dim=2)
    xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]
    ys = [ds.x.reconstruct(grab_idx(y, i)) for i in range(rows)]
    zs = [ds.x.reconstruct(grab_idx(z, i)) for i in range(rows)]
    items,names = [],['text', 'target', 'pred']
    # Distinct loop names: the batch tensors `x`, `y`, `z` are no longer shadowed.
    for inp, targ, pred in zip(xs,ys,zs):
        txt_x = ' '.join(inp.text.split(' ')[:max_len])
        txt_y = ' '.join(targ.text.split(' ')[max_len-1:2*max_len-1])
        txt_z = ' '.join(pred.text.split(' ')[max_len-1:2*max_len-1])
        items.append([txt_x, txt_y, txt_z])
    items = np.array(items)
    df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)
    # NOTE(review): -1 is the legacy "no truncation" value for this option; newer pandas
    # releases expect None instead - confirm against the supported pandas version.
    with pd.option_context('display.max_colwidth', -1):
        display(HTML(df.to_html(index=False)))
def get_language_model(arch:Callable, vocab_sz:int, config:dict=None, drop_mult:float=1.):
    "Build a language model: an `arch` encoder configured by `config`, topped with a `LinearDecoder`."
    meta = _model_meta[arch]
    # Work on a copy so neither the caller's dict nor the default config is modified.
    config = ifnone(config, meta['config_lm']).copy()
    # Every dropout probability (key ending in `_p`) is scaled by `drop_mult`.
    for name in config:
        if name.endswith('_p'): config[name] *= drop_mult
    # These entries configure the decoder/initialisation, not the encoder constructor.
    tie_weights = config.pop('tie_weights')
    output_p = config.pop('output_p')
    out_bias = config.pop('out_bias')
    init = config.pop('init', None)
    encoder = arch(vocab_sz, **config)
    tied = encoder.encoder if tie_weights else None
    decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=tied, bias=out_bias)
    model = SequentialRNN(encoder, decoder)
    if init is None: return model
    return model.apply(init)
def language_model_learner(data:DataBunch, arch, config:dict=None, drop_mult:float=1., pretrained:bool=True,
                           pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `Learner` with a language model from `data` and `arch`.

    When `pretrained_fnames` is given, the weights/vocab files `<name>.pth` and
    `<name>.pkl` are loaded from the learner's model directory. Otherwise, if
    `pretrained`, the archive registered for `arch` is downloaded. After loading
    pretrained weights the learner is frozen.
    """
    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    meta = _model_meta[arch]
    learn = LanguageLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
    # Backwards language models have their own pretrained archive.
    url = 'url_bwd' if data.backwards else 'url'
    if not (pretrained or pretrained_fnames): return learn
    if pretrained_fnames is not None:
        model_dir = learn.path/learn.model_dir
        fnames = [model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
    elif url not in meta:
        warn("There are no pretrained weights for that architecture yet!")
        return learn
    else:
        model_path = untar_data(meta[url] , data=False)
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    learn.load_pretrained(*fnames)
    learn.freeze()
    return learn
def masked_concat_pool(outputs, mask):
    "Pool the last element of `outputs` into `[last step, masked max-pool, masked avg-pool]`, ignoring positions where `mask` is set."
    last_layer = outputs[-1]
    seq_len = last_layer.size(1)
    pad = mask[:, :, None]
    # Mean over all steps with masked steps zeroed, then rescaled to a mean over unmasked steps only.
    avg_pool = last_layer.masked_fill(pad, 0).mean(dim=1)
    n_masked = mask.type(avg_pool.dtype).sum(dim=1)
    avg_pool *= seq_len / (seq_len - n_masked)[:, None]
    # Masked steps are set to -inf so they can never win the max.
    max_pool = last_layer.masked_fill(pad, -float('inf')).max(dim=1)[0]
    return torch.cat([last_layer[:, -1], max_pool, avg_pool], 1)
class PoolingLinearClassifier(Module):
    "Classifier head: concat-pools the encoder outputs, then applies a stack of `bn_drop_lin` blocks."
    def __init__(self, layers:Collection[int], drops:Collection[float]):
        n_blocks = len(layers) - 1
        if len(drops) != n_blocks: raise ValueError("Number of layers and dropout values do not match.")
        # One ReLU instance is reused after every block except the last, which has no activation.
        relu = nn.ReLU(inplace=True)
        blocks = []
        for i, p in enumerate(drops):
            actn = relu if i < n_blocks - 1 else None
            blocks.extend(bn_drop_lin(layers[i], layers[i+1], p=p, actn=actn))
        self.layers = nn.Sequential(*blocks)

    def forward(self, input:Tuple[Tensor,Tensor, Tensor])->Tuple[Tensor,Tensor,Tensor]:
        "Return `(predictions, raw_outputs, outputs)` for `input = (raw_outputs, outputs, mask)`."
        raw_outputs,outputs,mask = input
        pooled = masked_concat_pool(outputs, mask)
        return self.layers(pooled), raw_outputs, outputs
class MultiBatchEncoder(Module):
    "Run `module` over a full sentence in chunks of `bptt` tokens, keeping the outputs of roughly the last `max_len` tokens."
    def __init__(self, bptt:int, max_len:int, module:nn.Module, pad_idx:int=1):
        self.max_len = max_len
        self.bptt = bptt
        self.module = module
        self.pad_idx = pad_idx

    def concat(self, arrs:Collection[Tensor])->Tensor:
        "Concatenate the `arrs` along the batch dimension."
        n_layers = len(arrs[0])
        return [torch.cat([chunk[layer] for chunk in arrs], dim=1) for layer in range(n_layers)]

    def reset(self):
        "Forward the reset to the wrapped module when it supports it."
        if hasattr(self.module, 'reset'): self.module.reset()

    def forward(self, input:LongTensor)->Tuple[Tensor,Tensor]:
        "Return the concatenated raw outputs, outputs and padding mask of the kept chunks."
        _, sl = input.size()
        self.reset()
        raw_outputs,outputs,masks = [],[],[]
        first_kept = sl - self.max_len
        for start in range(0, sl, self.bptt):
            chunk = input[:, start:min(start + self.bptt, sl)]
            # Every chunk goes through the module so its hidden state advances,
            # but only chunks starting after `sl - max_len` are collected.
            raw, out = self.module(chunk)
            if start <= first_kept: continue
            masks.append(chunk == self.pad_idx)
            raw_outputs.append(raw)
            outputs.append(out)
        return self.concat(raw_outputs),self.concat(outputs),torch.cat(masks,dim=1)
def get_text_classifier(arch:Callable, vocab_sz:int, n_class:int, bptt:int=70, max_len:int=20*70, config:dict=None,
                        drop_mult:float=1., lin_ftrs:Collection[int]=None, ps:Collection[float]=None,
                        pad_idx:int=1) -> nn.Module:
    """Create a text classifier from `arch` and its `config`.

    The encoder is `arch(vocab_sz, **config)` wrapped in a `MultiBatchEncoder`; the head
    is a `PoolingLinearClassifier` whose hidden sizes are `lin_ftrs` (default `[50]`)
    with dropouts `ps` (default `0.1` each). Every config entry ending in `_p` is
    multiplied by `drop_mult`. If the config has an `init` entry it is applied to the
    whole model.
    """
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_clas']).copy()
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    if lin_ftrs is None: lin_ftrs = [50]
    if ps is None: ps = [0.1]*len(lin_ftrs)
    # `list(...)` lets callers pass any collection (e.g. a tuple), as the annotations promise;
    # concatenating a list with a tuple would raise a TypeError.
    # The head input is 3x the hidden size because of the [last, max, avg] concat pooling.
    layers = [config[meta['hid_name']] * 3] + list(lin_ftrs) + [n_class]
    ps = [config.pop('output_p')] + list(ps)
    init = config.pop('init') if 'init' in config else None
    encoder = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config), pad_idx=pad_idx)
    model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))
    return model if init is None else model.apply(init)
def text_classifier_learner(data:DataBunch, arch:Callable, bptt:int=70, max_len:int=70*20, config:dict=None,
                            pretrained:bool=True, drop_mult:float=1., lin_ftrs:Collection[int]=None,
                            ps:Collection[float]=None, **learn_kwargs) -> 'TextClassifierLearner':
    """Create a `Learner` with a text classifier from `data` and `arch`.

    With `pretrained`, the archive registered under `'url'` for `arch` is downloaded,
    loaded non-strictly and the learner is frozen; if no archive is registered a
    warning is issued and the learner is returned as is.
    """
    vocab_sz = len(data.vocab.itos)
    model = get_text_classifier(arch, vocab_sz, data.c, bptt=bptt, max_len=max_len,
                                config=config, drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)
    meta = _model_meta[arch]
    learn = RNNLearner(data, model, split_func=meta['split_clas'], **learn_kwargs)
    if not pretrained: return learn
    if 'url' not in meta:
        warn("There are no pretrained weights for that architecture yet!")
        return learn
    model_path = untar_data(meta['url'], data=False)
    fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    learn.load_pretrained(*fnames, strict=False)
    learn.freeze()
    return learn
================================================
FILE: fastai/text/models/__init__.py
================================================
from .awd_lstm import *
from .transformer import *
# Public API of the package: the union of the public names of both model submodules.
# `awd_lstm` and `transformer` are reachable here because importing from a submodule
# binds that submodule as an attribute of the package.
__all__ = [*awd_lstm.__all__, *transformer.__all__]
================================================
FILE: fastai/text/models/awd_lstm.py
================================================
from ...torch_core import *
from ...layers import *
from ...train import ClassificationInterpretation
from ...basic_train import *
from ...basic_data import *
from ..data import TextClasDataBunch
import matplotlib.cm as cm
# Names exported by `from .awd_lstm import *`: the model building blocks, the
# layer-group split functions, the default configs and the interpretation class.
__all__ = ['EmbeddingDropout', 'LinearDecoder', 'AWD_LSTM', 'RNNDropout',
'SequentialRNN', 'WeightDropout', 'dropout_mask', 'awd_lstm_lm_split', 'awd_lstm_clas_split',
'awd_lstm_lm_config', 'awd_lstm_clas_config', 'TextClassificationInterpretation']
def dropout_mask(x:Tensor, sz:Collection[int], p:float):
    "Build a mask of size `sz` and of the same type as `x`: each entry is 0 with probability `p`, else `1/(1-p)`."
    keep_prob = 1 - p
    mask = x.new(*sz)
    mask.bernoulli_(keep_prob)
    # Rescale the kept entries so the expected value of a masked tensor is unchanged.
    return mask.div_(keep_prob)
class RNNDropout(Module):
    "Dropout with probability `p` whose mask is shared along the second (sequence) dimension."
    def __init__(self, p:float=0.5): self.p=p

    def forward(self, x:Tensor)->Tensor:
        "Return `x` unchanged in eval mode or when `p` is 0, else `x` times a broadcast dropout mask."
        if self.training and self.p != 0.:
            # Size-1 middle dimension: the same mask is broadcast to every time step.
            mask = dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)
            return x * mask
        return x
class WeightDropout(Module):
    "A module that wraps another layer in which some weights will be replaced by 0 during training."
    def __init__(self, module:nn.Module, weight_p:float, layer_names:Collection[str]=['weight_hh_l0']):
        self.module,self.weight_p,self.layer_names = module,weight_p,layer_names
        for name in self.layer_names:
            weight = getattr(self.module, name)
            # The trainable copy lives on this wrapper as `<name>_raw`; the wrapped module's
            # own entry is replaced by a tensor derived from it.
            self.register_parameter(f'{name}_raw', nn.Parameter(weight.data))
            self.module._parameters[name] = F.dropout(weight, p=self.weight_p, training=False)

    def _assign_weights(self, training:bool):
        "Recompute the wrapped module's weights from the raw copies, applying dropout only if `training`."
        for name in self.layer_names:
            raw = getattr(self, f'{name}_raw')
            self.module._parameters[name] = F.dropout(raw, p=self.weight_p, training=training)

    def _setweights(self):
        "Apply dropout to the raw weights."
        self._assign_weights(self.training)

    def forward(self, *args:ArgStar):
        "Refresh the dropped-out weights, then run the wrapped module on `args`."
        self._setweights()
        with warnings.catch_warnings():
            #To avoid the warning that comes because the weights aren't flattened.
            warnings.simplefilter("ignore")
            return self.module.forward(*args)

    def reset(self):
        "Restore the weights without dropout and reset the wrapped module if it supports it."
        self._assign_weights(False)
        if hasattr(self.module, 'reset'): self.module.reset()
class EmbeddingDropout(Module):
    "Apply dropout with probability `embed_p` to an embedding layer `emb`."
    def __init__(self, emb:nn.Module, embed_p:float):
        self.emb,self.embed_p = emb,embed_p
        self.pad_idx = self.emb.padding_idx
        if self.pad_idx is None: self.pad_idx = -1

    def forward(self, words:LongTensor, scale:Optional[float]=None)->Tensor:
        """Embed `words`, dropping whole rows of the embedding matrix while training.

        In training mode (and when `embed_p` is not 0) each row of the weight matrix is
        zeroed with probability `embed_p`. If `scale` is truthy the (masked) weights are
        multiplied by it before the lookup.
        """
        if self.training and self.embed_p != 0:
            # One mask value per row: a dropped token is dropped everywhere in the batch.
            size = (self.emb.weight.size(0),1)
            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)
            masked_embed = self.emb.weight * mask
        else: masked_embed = self.emb.weight
        # Out-of-place on purpose: outside the dropout branch `masked_embed` IS the
        # embedding parameter, so an in-place `mul_` would rescale the stored weights on
        # every call (or fail on a leaf tensor that requires grad).
        if scale: masked_embed = masked_embed * scale
        return F.embedding(words, masked_embed, self.pad_idx, self.emb.max_norm,
                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)
class AWD_LSTM(Module):
    "AWD-LSTM/QRNN inspired by https://arxiv.org/abs/1708.02182."
    initrange=0.1

    def __init__(self, vocab_sz:int, emb_sz:int, n_hid:int, n_layers:int, pad_token:int=1, hidden_p:float=0.2,
                 input_p:float=0.6, embed_p:float=0.1, weight_p:float=0.5, qrnn:bool=False, bidir:bool=False):
        self.bs,self.qrnn,self.emb_sz,self.n_hid,self.n_layers = 1,qrnn,emb_sz,n_hid,n_layers
        self.n_dir = 2 if bidir else 1
        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)
        last = n_layers - 1
        # (input width, per-direction output width) of each layer: the first layer reads
        # embeddings and the last one maps back to the embedding size.
        dims = [(emb_sz if l == 0 else n_hid, (emb_sz if l == last else n_hid)//self.n_dir)
                for l in range(n_layers)]
        if self.qrnn:
            #Using QRNN requires an installation of cuda
            from .qrnn import QRNN
            rnns = [QRNN(n_in, n_out, 1, save_prev_x=True, zoneout=0, window=2 if l == 0 else 1,
                         output_gate=True, bidirectional=bidir) for l,(n_in,n_out) in enumerate(dims)]
            for rnn in rnns:
                rnn.layers[0].linear = WeightDropout(rnn.layers[0].linear, weight_p, layer_names=['weight'])
        else:
            rnns = [nn.LSTM(n_in, n_out, 1, batch_first=True, bidirectional=bidir) for n_in,n_out in dims]
            rnns = [WeightDropout(rnn, weight_p) for rnn in rnns]
        self.rnns = nn.ModuleList(rnns)
        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.input_dp = RNNDropout(input_p)
        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])

    def forward(self, input:Tensor, from_embeddings:bool=False)->Tuple[Tensor,Tensor]:
        "Return the per-layer raw outputs and dropped-out outputs for `input` (token ids, or embeddings if `from_embeddings`)."
        if from_embeddings: bs,_,_ = input.size()
        else: bs,_ = input.size()
        # A change of batch size invalidates the stored hidden state.
        if bs != self.bs:
            self.bs = bs
            self.reset()
        embedded = input if from_embeddings else self.encoder_dp(input)
        raw_output = self.input_dp(embedded)
        new_hidden,raw_outputs,outputs = [],[],[]
        top = self.n_layers - 1
        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):
            raw_output, new_h = rnn(raw_output, self.hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            # Hidden dropout is applied between layers, not after the top one.
            if l != top: raw_output = hid_dp(raw_output)
            outputs.append(raw_output)
        # Detach so that gradients do not flow across calls.
        self.hidden = to_detach(new_hidden, cpu=False)
        return raw_outputs, outputs

    def _one_hidden(self, l:int)->Tensor:
        "Return one hidden state."
        width = self.emb_sz if l == self.n_layers - 1 else self.n_hid
        return one_param(self).new(self.n_dir, self.bs, width // self.n_dir).zero_()

    def select_hidden(self, idxs):
        "Keep only the hidden states at batch positions `idxs` and update the batch size accordingly."
        if self.qrnn: self.hidden = [h[:,idxs,:] for h in self.hidden]
        else: self.hidden = [(h[0][:,idxs,:],h[1][:,idxs,:]) for h in self.hidden]
        self.bs = len(idxs)

    def reset(self):
        "Reset the hidden states."
        for rnn in self.rnns:
            if hasattr(rnn, 'reset'): rnn.reset()
        layers = range(self.n_layers)
        if self.qrnn: self.hidden = [self._one_hidden(l) for l in layers]
        else: self.hidden = [(self._one_hidden(l), self._one_hidden(l)) for l in layers]
class LinearDecoder(Module):
    "Linear head placed on top of an RNN core to turn it into a language model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, output_p:float, tie_encoder:nn.Module=None, bias:bool=True):
        linear = nn.Linear(n_hid, n_out, bias=bias)
        linear.weight.data.uniform_(-self.initrange, self.initrange)
        self.decoder = linear
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        # Weight tying: the decoder shares its weight matrix with the given encoder.
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        "Decode the last element of `outputs` and return `(decoded, raw_outputs, outputs)`."
        raw_outputs, outputs = input
        dropped = self.output_dp(outputs[-1])
        return self.decoder(dropped), raw_outputs, outputs
class SequentialRNN(nn.Sequential):
    "An `nn.Sequential` whose `reset` is forwarded to every child that defines one."
    def reset(self):
        "Call `reset` on each direct child that has such an attribute."
        resettable = [child for child in self.children() if hasattr(child, 'reset')]
        for child in resettable: child.reset()
def awd_lstm_lm_split(model:nn.Module) -> List[nn.Module]:
    "Split a language `model` in groups for differential learning rates: one group per RNN layer, then embedding + decoder."
    encoder, decoder = model[0], model[1]
    groups = []
    for rnn, dp in zip(encoder.rnns, encoder.hidden_dps):
        groups.append([rnn, dp])
    groups.append([encoder.encoder, encoder.encoder_dp, decoder])
    return groups
def awd_lstm_clas_split(model:nn.Module) -> List[nn.Module]:
    "Split a classifier `model` in groups for differential learning rates: embedding, one group per RNN layer, then the head."
    core, head = model[0].module, model[1]
    embedding_group = [core.encoder, core.encoder_dp]
    rnn_groups = [[rnn, dp] for rnn, dp in zip(core.rnns, core.hidden_dps)]
    return [embedding_group, *rnn_groups, [head]]
# Default hyper-parameters for the AWD-LSTM language model. The `*_p` entries are dropout
# probabilities; `tie_weights` and `out_bias` configure the decoder.
awd_lstm_lm_config = {
    'emb_sz': 400, 'n_hid': 1152, 'n_layers': 3, 'pad_token': 1, 'qrnn': False, 'bidir': False,
    'output_p': 0.1, 'hidden_p': 0.15, 'input_p': 0.25, 'embed_p': 0.02, 'weight_p': 0.2,
    'tie_weights': True, 'out_bias': True,
}

# Default hyper-parameters for the AWD-LSTM classifier (same architecture sizes, larger dropouts).
awd_lstm_clas_config = {
    'emb_sz': 400, 'n_hid': 1152, 'n_layers': 3, 'pad_token': 1, 'qrnn': False, 'bidir': False,
    'output_p': 0.4, 'hidden_p': 0.3, 'input_p': 0.4, 'embed_p': 0.05, 'weight_p': 0.5,
}
def value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:
    "Map `x` (from 0 to 1 inclusive) through `cmap` to an `(r, g, b, a)` tuple: integer colour channels scaled by 255, alpha multiplied by `alpha_mult`."
    *rgb, alpha = cmap(x)
    channels = (np.array(rgb) * 255).astype(int).tolist()
    return tuple(channels + [alpha * alpha_mult])
def piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:
    """Return HTML that highlights each of `pieces` according to its attention in `attns`.

    Every piece is HTML-escaped and wrapped in a `<span>` whose background colour comes
    from `value2rgba(attn, alpha_mult=0.5, **kwargs)` and whose tooltip shows the
    attention value. The spans are joined with `sep` inside a monospace wrapper span.
    """
    spans = []
    for p, a in zip(pieces, attns):
        p = html.escape(p)
        c = str(value2rgba(a, alpha_mult=0.5, **kwargs))
        # `c` is the textual RGBA tuple, so `rgba{c}` forms a CSS `rgba(r, g, b, a)` colour.
        spans.append(f'<span title="{a:.3f}" style="background-color: rgba{c};">{p}</span>')
    return ''.join(['<span style="font-family: monospace;">', sep.join(spans), '</span>'])
def show_piece_attn(*args, **kwargs):
    "Render the output of `piece_attn_html(*args, **kwargs)` in the notebook."
    # Imported lazily so the module does not require IPython at import time.
    from IPython.display import display, HTML
    markup = piece_attn_html(*args, **kwargs)
    display(HTML(markup))
def _eval_dropouts(mod):
module_name = mod.__class__.__name__
if 'Dropout' in module_name or 'BatchNorm' in module_name: mod.training = False
for module in mod.children(): _eval_dropouts(module)
class TextClassificationInterpretation(ClassificationInterpretation):
    """Provides an interpretation of classification based on input sensitivity.
    This was designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.
    """
    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):
        super().__init__(learn,preds,y_true,losses,ds_type)
        self.model = learn.model

    def intrinsic_attention(self, text:str, class_id:int=None):
        """Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.
        For reference, see the Sequential Jacobian section at https://www.cs.toronto.edu/~graves/preprint.pdf
        Returns the reconstructed tokens and one attention value per token, normalised by the maximum.
        """
        # Train mode for the forward pass through the encoder, with dropout/batchnorm
        # layers individually put back in eval mode.
        self.model.train()
        _eval_dropouts(self.model)
        self.model.zero_grad()
        self.model.reset()
        ids = self.data.one_item(text)[0]
        encoder = self.model[0].module
        # Gradients are taken w.r.t. the embeddings, so they are detached into a leaf tensor.
        emb = encoder.encoder(ids).detach().requires_grad_(True)
        lstm_output = encoder(emb, from_embeddings=True)
        self.model.eval()
        no_padding = torch.zeros_like(ids).byte()
        cl = self.model[1](lstm_output + (no_padding,))[0].softmax(dim=-1)
        if class_id is None: class_id = cl.argmax()
        cl[0][class_id].backward()
        # Per token: sum of absolute gradients over the embedding dimension.
        attn = emb.grad.squeeze().abs().sum(dim=-1)
        attn /= attn.max()
        tokens = self.data.single_ds.reconstruct(ids[0])
        return tokens, attn

    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:
        "Return the HTML rendering of the intrinsic attention of `text`."
        tokens, attn = self.intrinsic_attention(text, class_id)
        return piece_attn_html(tokens.text.split(), to_np(attn), **kwargs)

    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:
        "Display the intrinsic attention of `text` in the notebook."
        tokens, attn = self.intrinsic_attention(text, class_id)
        show_piece_attn(tokens.text.split(), to_np(attn), **kwargs)

    def show_top_losses(self, k:int, max_len:int=70)->None:
        """
        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability of
        actual class. `max_len` is the maximum number of tokens displayed.
        """
        from IPython.display import display, HTML
        _, top_idxs = self.top_losses()
        rows = []
        for n_shown, idx in enumerate(top_idxs):
            if n_shown >= k: break
            tx,cl = self.data.dl(self.ds_type).dataset[idx]
            cl = cl.data
            classes = self.data.classes
            if max_len is not None: txt = ' '.join(tx.text.split(' ')[:max_len])
            else: txt = tx.text
            rows.append([txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',
                         f'{self.preds[idx][cl]:.2f}'])
        items = np.array(rows)
        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']
        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
================================================
FILE: fastai/text/models/bwd_forget_mult_cuda.cpp
================================================
#include
#include
// CUDA forward declarations
// Declaration only; takes x, f and the pre-allocated output buffer and returns a tensor.
at::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first);
// C++ interface
// Argument validation helpers; each assertion message contains the stringified argument name.
// CHECK_CUDA: assert that the tensor lives on a CUDA device.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
// CHECK_CONTIGUOUS: assert that the tensor is contiguous.
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// CHECK_INPUT: both checks. It expands to two statements, so use it only as a
// stand-alone statement (never as the body of an unbraced `if`).
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Python-facing forward entry point: validates that every tensor is a contiguous
// CUDA tensor, then delegates to the CUDA implementation.
at::Tensor bwd_forget_mult_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  CHECK_INPUT(x);
  CHECK_INPUT(f);
  CHECK_INPUT(output);
  at::Tensor result = bwd_forget_mult_cuda_forward(x, f, output, batch_first);
  return result;
}
std::vector bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f, at::Tensor output,
at::Tensor grad_output, bool batch_first);
std::vector bwd_forget_mult_backward(at::Tensor x, at::Tensor f, at::Tensor output,
at::Tensor grad_output, bool batch_first) {
CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output);
return bwd_forget_mult_cuda_backward(x, f, output, grad_output, batch_first);
}
// Python bindings: expose the two wrappers as `forward` and `backward` of the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &bwd_forget_mult_forward,
        "BwdForgetMult forward (CUDA)");
  m.def("backward",
        &bwd_forget_mult_backward,
        "BwdForgetMult backward (CUDA)");
}
================================================
FILE: fastai/text/models/bwd_forget_mult_cuda_kernel.cu
================================================
#include
#include
#include
#include
#include
// Forward pass of the reversed-time ForgetMult recurrence:
//   output[t] = f[t] * x[t] + (1 - f[t]) * output[t+1]
// Each thread handles one (batch, hidden) pair and walks the sequence from the last
// timestep to the first.
template <typename scalar_t>
__global__ void bwd_forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ f, scalar_t* __restrict__ output,
    size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  /*
  Note: output is assumed to be one timestep longer than f or x where output[seq_length] = h_{+1}
  This means output array has a size of seq_length+1 on the word dimension
  */
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid < n_hidden && bid < batch_size){
    for (int ts = seq_length-1; ts >= 0; ts--) {
      int i = 0;           // flat index into x / f
      int dst_i = 0;       // flat index of output[ts]
      int dst_iplus1 = 0;  // flat index of output[ts+1]
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts+0) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts+0) * n_hidden + hid;
        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;
      }
      else {
        i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iplus1 = (ts+1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      output[dst_i] = f[i] * x[i];
      output[dst_i] += (1 - f[i]) * output[dst_iplus1];
    }
  }
}
// Backward pass of the reversed-time ForgetMult recurrence. Each thread handles one
// (batch, hidden) pair and walks the sequence from the first timestep to the last,
// carrying the accumulated gradient in `running_f`; what is left after the last step
// is written to grad_h (gradient w.r.t. the initial hidden state).
template <typename scalar_t>
__global__ void bwd_forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,
    const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,
    scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,
    size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  double running_f = 0;
  if(hid < n_hidden && bid < batch_size){
    for (int ts = 0; ts < seq_length; ts++) {
      int i = 0;           // flat index into x / f / grad_output / grad_x / grad_f
      int dst_iplus1 = 0;  // flat index of output[ts+1]
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts+0) * n_hidden + hid;
        dst_iplus1 = bid * n_hidden * (seq_length+1) + (ts+1) * n_hidden + hid;
      }
      else {
        i = (ts+0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iplus1 = (ts+1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      running_f += grad_output[i];
      grad_x[i] = f[i] * running_f;
      grad_f[i] = (x[i] - output[dst_iplus1]) * running_f;
      // The line below is likely more numerically stable than (1 - f[i]) * running_f;
      running_f = running_f - f[i] * running_f;
    }
    grad_h[bid * n_hidden + hid] = running_f;
  }
}
// Host launcher for the forward kernel. `output` is filled in place and returned.
at::Tensor bwd_forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  // Grid: x covers the hidden units in chunks of `threads`, y has one row per batch element.
  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "bwd_forget_mult_cuda_forward", ([&] {
    bwd_forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return output;
}
// Host launcher for the backward kernel. Returns {grad_x, grad_f, grad_h}, where
// grad_x and grad_f have the shape of x and grad_h is (batch_size, n_hidden).
std::vector<at::Tensor> bwd_forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
    at::Tensor output, at::Tensor grad_output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  auto grad_x = at::zeros_like(x);
  auto grad_f = at::zeros_like(x);
  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());

  // Grid: x covers the hidden units in chunks of `threads`, y has one row per batch element.
  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  // The dispatch label names this function (it previously carried the forward function's name).
  AT_DISPATCH_FLOATING_TYPES(x.type(), "bwd_forget_mult_cuda_backward", ([&] {
    bwd_forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),
        grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return {grad_x, grad_f, grad_h};
}
================================================
FILE: fastai/text/models/forget_mult_cuda.cpp
================================================
#include
#include
// CUDA forward declarations
// Declaration only; takes x, f and the pre-allocated output buffer and returns a tensor.
at::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first);
// C++ interface
// Argument validation helpers; each assertion message contains the stringified argument name.
// CHECK_CUDA: assert that the tensor lives on a CUDA device.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
// CHECK_CONTIGUOUS: assert that the tensor is contiguous.
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// CHECK_INPUT: both checks. It expands to two statements, so use it only as a
// stand-alone statement (never as the body of an unbraced `if`).
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Python-facing forward entry point: validates that every tensor is a contiguous
// CUDA tensor, then delegates to the CUDA implementation.
at::Tensor forget_mult_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  CHECK_INPUT(x);
  CHECK_INPUT(f);
  CHECK_INPUT(output);
  at::Tensor result = forget_mult_cuda_forward(x, f, output, batch_first);
  return result;
}
std::vector forget_mult_cuda_backward(at::Tensor x, at::Tensor f, at::Tensor output,
at::Tensor grad_output, bool batch_first);
std::vector forget_mult_backward(at::Tensor x, at::Tensor f, at::Tensor output,
at::Tensor grad_output, bool batch_first) {
CHECK_INPUT(x); CHECK_INPUT(f); CHECK_INPUT(output);
return forget_mult_cuda_backward(x, f, output, grad_output, batch_first);
}
// Python bindings: expose the two wrappers as `forward` and `backward` of the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &forget_mult_forward,
        "ForgetMult forward (CUDA)");
  m.def("backward",
        &forget_mult_backward,
        "ForgetMult backward (CUDA)");
}
================================================
FILE: fastai/text/models/forget_mult_cuda_kernel.cu
================================================
#include
#include
#include
#include
#include
// Forward pass of the ForgetMult recurrence:
//   output[t] = f[t-1] * x[t-1] + (1 - f[t-1]) * output[t-1]
// Each thread handles one (batch, hidden) pair and walks the sequence from the first
// timestep to the last.
template <typename scalar_t>
__global__ void forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ f, scalar_t* __restrict__ output,
    size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  /*
  Note: output is assumed to be one timestep longer than f or x where output[0] = h_{-1}
  This means output array has a size of seq_length+1 on the word dimension
  */
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid < n_hidden && bid < batch_size){
    for (int ts = 1; ts < seq_length + 1; ts++) {
      int i = 0;            // flat index into x / f (timestep ts-1)
      int dst_i = 0;        // flat index of output[ts]
      int dst_iminus1 = 0;  // flat index of output[ts-1]
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts-1) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts-0) * n_hidden + hid;
        dst_iminus1 = bid * n_hidden * (seq_length+1) + (ts-1) * n_hidden + hid;
      }
      else {
        i = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts-0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iminus1 = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      output[dst_i] = f[i] * x[i];
      output[dst_i] += (1 - f[i]) * output[dst_iminus1];
    }
  }
}
// Backward pass of the ForgetMult recurrence. Each thread handles one (batch, hidden)
// pair and walks the sequence from the last timestep to the first, carrying the
// accumulated gradient in `running_f`; what is left after the first step is written
// to grad_h (gradient w.r.t. the initial hidden state).
template <typename scalar_t>
__global__ void forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,
    const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,
    const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,
    scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,
    size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  double running_f = 0;
  if(hid < n_hidden && bid < batch_size){
    for (int ts = seq_length; ts >= 0 + 1; ts--) {
      int i = 0;            // flat index into x / f / grad_output / grad_x / grad_f (timestep ts-1)
      int dst_i = 0;        // flat index of output[ts] (computed but not read below)
      int dst_iminus1 = 0;  // flat index of output[ts-1]
      if (batch_first){
        i = bid * n_hidden * seq_length + (ts-1) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length+1) + (ts-0) * n_hidden + hid;
        dst_iminus1 = bid * n_hidden * (seq_length+1) + (ts-1) * n_hidden + hid;
      }
      else {
        i = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = (ts-0) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iminus1 = (ts-1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      running_f += grad_output[i];
      grad_x[i] = f[i] * running_f;
      grad_f[i] = (x[i] - output[dst_iminus1]) * running_f;
      // The line below is likely more numerically stable than (1 - f[i]) * running_f;
      running_f = running_f - f[i] * running_f;
    }
    grad_h[bid * n_hidden + hid] = running_f;
  }
}
// Host launcher for the forward kernel. `output` is filled in place and returned.
at::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  // Grid: x covers the hidden units in chunks of `threads`, y has one row per batch element.
  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forget_mult_cuda_forward", ([&] {
    forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return output;
}
// Host launcher for the backward kernel. Returns {grad_x, grad_f, grad_h}, where
// grad_x and grad_f have the shape of x and grad_h is (batch_size, n_hidden).
std::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
    at::Tensor output, at::Tensor grad_output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);

  auto grad_x = at::zeros_like(x);
  auto grad_f = at::zeros_like(x);
  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());

  // Grid: x covers the hidden units in chunks of `threads`, y has one row per batch element.
  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  // The dispatch label names this function (it previously carried the forward function's name).
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forget_mult_cuda_backward", ([&] {
    forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),
        grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));

  THCudaCheck(cudaGetLastError());
  return {grad_x, grad_f, grad_h};
}
================================================
FILE: fastai/text/models/qrnn.py
================================================
from ...torch_core import *
from torch.utils.cpp_extension import load
from torch.autograd import Function
__all__ = ['QRNNLayer', 'QRNN']
import fastai
# Load the two CUDA extensions at import time, only when CUDA is available.
# Without CUDA the names `forget_mult_cuda` and `bwd_forget_mult_cuda` are never bound.
if torch.cuda.is_available():
# Directory of the C++/CUDA sources inside the installed fastai package.
fastai_path = Path(fastai.__path__[0])/'text'/'models'
# Forward-direction ForgetMult extension.
files = ['forget_mult_cuda.cpp', 'forget_mult_cuda_kernel.cu']
forget_mult_cuda = load(name='forget_mult_cuda', sources=[fastai_path/f for f in files])
# Reversed-direction ForgetMult extension.
files = ['bwd_forget_mult_cuda.cpp', 'bwd_forget_mult_cuda_kernel.cu']
bwd_forget_mult_cuda = load(name='bwd_forget_mult_cuda', sources=[fastai_path/f for f in files])
def dispatch_cuda(cuda_class, cpu_func, x):
    "Pick the implementation matching the device of `x`: `cuda_class.apply` on CUDA, else `cpu_func`."
    if x.device.type == 'cuda': return cuda_class.apply
    return cpu_func
class ForgetMultGPU(Function):
    "Autograd function running the ForgetMult recurrence `h_t = f_t * x_t + (1 - f_t) * h_{t-1}` with the CUDA extension."
    @staticmethod
    def forward(ctx, x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True):
        "Return the hidden states for inputs `x` and forget gates `f`, starting from `hidden_init` (zeros if `None`)."
        # The buffer has one extra leading timestep that holds the initial hidden state.
        if batch_first:
            batch_size, seq_size, hidden_size = f.size()
            output = f.new_zeros(batch_size, seq_size + 1, hidden_size)
            if hidden_init is not None: output[:, 0] = hidden_init
        else:
            seq_size, batch_size, hidden_size = f.size()
            output = f.new_zeros(seq_size + 1, batch_size, hidden_size)
            if hidden_init is not None: output[0] = hidden_init
        output = forget_mult_cuda.forward(x, f, output, batch_first)
        ctx.save_for_backward(x, f, hidden_init, output)
        ctx.batch_first = batch_first
        # Drop the slot of the initial state.
        return output[:,1:] if batch_first else output[1:]

    @staticmethod
    def backward(ctx, grad_output):
        "Return the gradients w.r.t. `x`, `f` and `hidden_init` (`None` for `hidden_init` if absent, and for `batch_first`)."
        x, f, hidden_init, output = ctx.saved_tensors
        # The kernel reads `grad_output` through a raw pointer assuming a dense layout,
        # which autograd does not guarantee, hence `.contiguous()`.
        grad_x, grad_f, grad_h = forget_mult_cuda.backward(x, f, output, grad_output.contiguous(), ctx.batch_first)
        return (grad_x, grad_f, (None if hidden_init is None else grad_h), None)
class BwdForgetMultGPU(Function):
    "Autograd function running the reversed-time ForgetMult recurrence `h_t = f_t * x_t + (1 - f_t) * h_{t+1}` with the CUDA extension."
    @staticmethod
    def forward(ctx, x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True):
        "Return the hidden states for inputs `x` and forget gates `f`, starting from `hidden_init` (zeros if `None`) at the end of the sequence."
        # The buffer has one extra trailing timestep that holds the initial hidden state.
        if batch_first:
            batch_size, seq_size, hidden_size = f.size()
            output = f.new_zeros(batch_size, seq_size + 1, hidden_size)
            if hidden_init is not None: output[:, -1] = hidden_init
        else:
            seq_size, batch_size, hidden_size = f.size()
            output = f.new_zeros(seq_size + 1, batch_size, hidden_size)
            if hidden_init is not None: output[-1] = hidden_init
        output = bwd_forget_mult_cuda.forward(x, f, output, batch_first)
        ctx.save_for_backward(x, f, hidden_init, output)
        ctx.batch_first = batch_first
        # Drop the slot of the initial state.
        return output[:,:-1] if batch_first else output[:-1]

    @staticmethod
    def backward(ctx, grad_output:Tensor):
        "Return the gradients w.r.t. `x`, `f` and `hidden_init` (`None` for `hidden_init` if absent, and for `batch_first`)."
        x, f, hidden_init, output = ctx.saved_tensors
        # The kernel reads `grad_output` through a raw pointer assuming a dense layout,
        # which autograd does not guarantee, hence `.contiguous()`.
        grad_x, grad_f, grad_h = bwd_forget_mult_cuda.backward(x, f, output, grad_output.contiguous(), ctx.batch_first)
        return (grad_x, grad_f, (None if hidden_init is None else grad_h), None)
def forget_mult_CPU(x:Tensor, f:Tensor, hidden_init:Optional[Tensor]=None, batch_first:bool=True, backward:bool=False):
    "Pure PyTorch ForgetMult: `h_t = f_t * x_t + (1 - f_t) * h_prev`, scanned forward in time or in reverse when `backward`."
    seq_dim = 1 if batch_first else 0
    f_steps = f.split(1, dim=seq_dim)
    x_steps = x.split(1, dim=seq_dim)
    state = None if hidden_init is None else hidden_init.unsqueeze(seq_dim)
    order = range(len(x_steps))
    if backward: order = reversed(order)
    states = []
    for t in order:
        candidate = x_steps[t] * f_steps[t]
        # Without an initial state, the first step has nothing to carry over.
        state = candidate if state is None else candidate + (1-f_steps[t]) * state
        states.append(state)
    # States were collected in scan order; put them back in time order.
    if backward: states.reverse()
    return torch.cat(states, dim=seq_dim)
class QRNNLayer(Module):
    "Apply a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence."
    def __init__(self, input_size:int, hidden_size:int=None, save_prev_x:bool=False, zoneout:float=0, window:int=1,
                 output_gate:bool=True, batch_first:bool=True, backward:bool=False):
        super().__init__()
        assert window in [1, 2], "This QRNN implementation currently only handles convolutional window of size 1 or size 2"
        self.save_prev_x = save_prev_x
        self.zoneout = zoneout
        self.window = window
        self.output_gate = output_gate
        self.batch_first = batch_first
        self.backward = backward
        hidden_size = ifnone(hidden_size, input_size)
        # All gates come out of one projection (a single large matmul) and are split again in `forward`.
        n_gates = 3 if output_gate else 2
        self.linear = nn.Linear(window * input_size, n_gates * hidden_size)
        self.prevX = None

    def reset(self):
        "Forget the saved previous input; call this before starting on a new sequence when `save_prev_x` is used."
        self.prevX = None

    def forward(self, inp, hid=None):
        "Return the layer output together with the cell state to carry over (first step if `backward`, else last)."
        gates = self.linear(self._get_source(inp))
        if self.output_gate: z_gate,f_gate,o_gate = gates.chunk(3, dim=2)
        else: z_gate,f_gate = gates.chunk(2, dim=2)
        z_gate.tanh_()
        f_gate.sigmoid_()
        # Zoneout: multiply the forget gate by a dropout mask, in training mode only.
        if self.zoneout and self.training:
            mask = dropout_mask(f_gate, f_gate.size(), self.zoneout).requires_grad_(False)
            f_gate = f_gate * mask
        z_gate = z_gate.contiguous()
        f_gate = f_gate.contiguous()
        if self.backward:
            forget_mult = dispatch_cuda(BwdForgetMultGPU, partial(forget_mult_CPU, backward=True), inp)
        else:
            forget_mult = dispatch_cuda(ForgetMultGPU, forget_mult_CPU, inp)
        c_gate = forget_mult(z_gate, f_gate, hid, self.batch_first)
        if self.output_gate: output = torch.sigmoid(o_gate) * c_gate
        else: output = c_gate
        if self.window > 1 and self.save_prev_x:
            # Keep the boundary step of this input for the next call: the first one when running in reverse, else the last.
            if self.backward: edge = inp[:, :1] if self.batch_first else inp[:1]
            else: edge = inp[:, -1:] if self.batch_first else inp[-1:]
            self.prevX = edge.detach()
        idx = 0 if self.backward else -1
        new_hid = c_gate[:, idx] if self.batch_first else c_gate[idx]
        return output, new_hid

    def _get_source(self, inp):
        "For `window=2`, concatenate `inp` with a copy of itself shifted by one step along the sequence axis."
        if self.window == 1: return inp
        seq_dim = 1 if self.batch_first else 0
        # The step entering at the boundary is the saved previous input, or zeros when there is none.
        if self.prevX is None: boundary = torch.zeros_like(inp[:, :1] if self.batch_first else inp[:1])
        else: boundary = self.prevX
        if self.backward:
            rest = inp[:, 1:] if self.batch_first else inp[1:]
            pieces = [rest, boundary]
        else:
            rest = inp[:, :-1] if self.batch_first else inp[:-1]
            pieces = [boundary, rest]
        shifted = torch.cat(pieces, seq_dim)
        return torch.cat([inp, shifted], 2)
class QRNN(Module):
    "Apply a multiple layer Quasi-Recurrent Neural Network (QRNN) to an input sequence."
    def __init__(self, input_size:int, hidden_size:int, n_layers:int=1, bias:bool=True, batch_first:bool=True,
                 dropout:float=0, bidirectional:bool=False, save_prev_x:bool=False, zoneout:float=0, window:int=None,
                 output_gate:bool=True):
        assert not (save_prev_x and bidirectional), "Can't save the previous X with bidirectional."
        assert bias == True, 'Removing underlying bias is not yet supported'
        super().__init__()
        kwargs = dict(batch_first=batch_first, zoneout=zoneout, output_gate=output_gate)
        # When `window` is None the first layer uses a window of 2 and the following ones a window of 1.
        self.layers = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, save_prev_x=save_prev_x,
                                               window=((2 if l ==0 else 1) if window is None else window), **kwargs)
                                     for l in range(n_layers)])
        if bidirectional:
            self.layers_bwd = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size,
                                                       backward=True, window=((2 if l ==0 else 1) if window is None else window),
                                                       **kwargs) for l in range(n_layers)])
        self.n_layers,self.batch_first,self.dropout,self.bidirectional = n_layers,batch_first,dropout,bidirectional

    def reset(self):
        "If your convolutional window is greater than 1 and you save previous xs, you must reset at the beginning of each new sequence."
        for layer in self.layers: layer.reset()
        if self.bidirectional:
            for layer in self.layers_bwd: layer.reset()

    def forward(self, inp, hid=None):
        "Run `inp` through every layer; return the last output and the stacked hidden states of all layers."
        new_hid = []
        if self.bidirectional: inp_bwd = inp.clone()
        for i, layer in enumerate(self.layers):
            # In bidirectional mode `hid` interleaves forward (even index) and backward (odd index) states.
            inp, h = layer(inp, None if hid is None else hid[2*i if self.bidirectional else i])
            new_hid.append(h)
            if self.bidirectional:
                inp_bwd, h_bwd = self.layers_bwd[i](inp_bwd, None if hid is None else hid[2*i+1])
                new_hid.append(h_bwd)
            if self.dropout != 0 and i < len(self.layers) - 1:
                # Dropout between layers (not after the last one). The result must be re-bound to the
                # tensors fed to the next layer: assigning it to a loop variable silently discards it.
                inp = F.dropout(inp, p=self.dropout, training=self.training, inplace=False)
                if self.bidirectional:
                    inp_bwd = F.dropout(inp_bwd, p=self.dropout, training=self.training, inplace=False)
        if self.bidirectional: inp = torch.cat([inp, inp_bwd], dim=2)
        return inp, torch.stack(new_hid, 0)
================================================
FILE: fastai/text/models/transformer.py
================================================
from ...torch_core import *
from ...layers import *
from .awd_lstm import RNNDropout, LinearDecoder, SequentialRNN
# Names exported by a star-import of this module.
__all__ = ['Activation', 'PositionalEncoding', 'GeLU', 'Swish', 'feed_forward', 'MultiHeadAttention', 'MultiHeadRelativeAttention',
'DecoderLayer', 'Transformer', 'TransformerXL', 'tfmer_lm_config', 'tfmer_clas_config', 'tfmer_lm_split', 'tfmer_clas_split',
'tfmerXL_lm_config', 'tfmerXL_clas_config', 'tfmerXL_lm_split', 'tfmerXL_clas_split']
# Identifiers of the available activations; they are the keys of `_activ_func` used by `feed_forward`.
Activation = Enum('Activation', 'ReLU Swish GeLU')
class PositionalEncoding(Module):
    "Sinusoidal encoding of positions: sines then cosines of `pos` times fixed frequencies."
    def __init__(self, d:int):
        # One frequency for every other index in `[0, d)`, kept as a (non-trained) buffer.
        inv_freq = 1 / (10000 ** (torch.arange(0., d, 2.)/d))
        self.register_buffer('freq', inv_freq)

    def forward(self, pos:Tensor):
        "Return the outer product of `pos` and the frequencies, as `[sin | cos]` along the last axis."
        angles = torch.ger(pos, self.freq)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
class GeLU(Module):
    "GeLU activation computed with the tanh formula."
    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
class Swish(Module):
    "Swish activation: the input multiplied by its own sigmoid."
    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
# Lookup from `Activation` member to a module instance; each instance is created once here, so every lookup returns the same object.
_activ_func = {Activation.ReLU:nn.ReLU(inplace=True), Activation.GeLU:GeLU(), Activation.Swish: Swish()}
def feed_forward(d_model:int, d_ff:int, ff_p:float=0., act:Activation=Activation.ReLU, double_drop:bool=True):
    "Build the feed-forward block: linear -> `act` -> (dropout if `double_drop`) -> linear -> dropout -> `MergeLayer` -> layer norm."
    expand = [nn.Linear(d_model, d_ff), _activ_func[act]]
    if double_drop: expand += [nn.Dropout(ff_p)]
    project = [nn.Linear(d_ff, d_model), nn.Dropout(ff_p), MergeLayer(), nn.LayerNorm(d_model)]
    return SequentialEx(*expand, *project)
class MultiHeadAttention(Module):
    "Multi-head self-attention followed by a residual connection and layer norm."
    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True):
        d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads = n_heads
        self.d_head = d_head
        self.scale = scale
        # Queries, keys and values are produced by one projection and split afterwards.
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att = nn.Dropout(attn_p)
        self.drop_res = nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Attend over `x`, project, apply residual dropout, add `x` back and normalize."
        attended = self._apply_attention(x, mask=mask, **kwargs)
        return self.ln(x + self.drop_res(self.out(attended)))

    def _apply_attention(self, x:Tensor, mask:Tensor=None):
        "Attention implemented with `permute` + `matmul`; positions where `mask` is set get a score of -inf."
        bs,x_len = x.size(0),x.size(1)
        def _heads(t): return t.view(bs, t.size(1), self.n_heads, self.d_head)
        wq,wk,wv = [_heads(t) for t in torch.chunk(self.attention(x), 3, dim=-1)]
        # Heads move to axis 1; keys are additionally transposed so that `matmul` yields the scores.
        wq = wq.permute(0, 2, 1, 3)
        wk = wk.permute(0, 2, 3, 1)
        wv = wv.permute(0, 2, 1, 3)
        attn_score = torch.matmul(wq, wk)
        if self.scale: attn_score.div_(self.d_head ** 0.5)
        if mask is not None:
            # The fill is done in float and cast back to the dtype of the scores.
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)

    def _attention_einsum(self, x, mask=None):
        "Same computation written with `einsum`: more readable, slightly slower than permute and matmul."
        bs,x_len = x.size(0),x.size(1)
        def _heads(t): return t.view(bs, t.size(1), self.n_heads, self.d_head)
        wq,wk,wv = [_heads(t) for t in torch.chunk(self.attention(x), 3, dim=-1)]
        attn_score = torch.einsum('bind,bjnd->bijn', (wq, wk))
        if self.scale: attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))
        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))
        return attn_vec.contiguous().view(bs, x_len, -1)
#def _line_shift1(x:Tensor, mask:bool=False):
# "Shift the line i of `x` by p-i elements to the left; if `mask`, puts 0s on the diagonal."
# bs,n,p,nh = x.size()
# x_pad = torch.cat([x.new_zeros(bs,n,1,nh), x], dim=2)
# x_shift = x_pad.view(bs,p + 1,n,nh)[:,1:].view_as(x)
# if mask: x_shift.mul_(torch.tril(x.new_ones(n,p), p-n)[None,:,:,None])
# return x_shift
def _line_shift(x:Tensor, mask:bool=False):
    "Shift the line i of `x` by p-i elements to the left; if `mask`, multiply the result by a lower-triangular matrix of ones."
    bs, nh, n, p = x.size()
    # Prepend one zero to every line, then read the same memory with the last two axes resized:
    # dropping the first row of that view moves each line by one more element than the previous one.
    zero_col = x.new_zeros(bs, nh, n, 1)
    padded = torch.cat([zero_col, x], dim=3)
    shifted = padded.view(bs, nh, p + 1, n)[:, :, 1:].view_as(x)
    if mask:
        keep = torch.tril(x.new_ones(n, p), p - n)
        shifted.mul_(keep[None, None])
    return shifted
class MultiHeadRelativeAttention(MultiHeadAttention):
    "MutiHeadAttention with relative positional encoding."
    def __init__(self, n_heads:int, d_model:int, d_head:int, resid_p:float=0., attn_p:float=0., bias:bool=True,
                 scale:bool=True):
        super().__init__(n_heads, d_model, d_head, resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        # Extra projection applied to the relative position encodings `r`.
        self.r_attn = nn.Linear(d_model, n_heads * d_head, bias=bias)

    def _apply_attention(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        "Relative attention with `permute` + `matmul`; `r` is required despite its `None` default (its size is read)."
        #Notations from the paper: x input, r vector of relative distance between two elements, u and v learnable
        #parameters of the model common between all layers, mask to avoid cheating and mem the previous hidden states.
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        # Queries are only needed for the current input, not for the memory.
        wq = wq[:,-x_len:]
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        wkr = self.r_attn(r)
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        wkr = wkr.permute(1,2,0)
        #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.matmul(wq+u,wk)
        BD = _line_shift(torch.matmul(wq+v, wkr))
        # The score must exist whether or not scaling is requested (it used to be defined only when `scale` was set).
        attn_score = AC + BD
        if self.scale: attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().view(bs, x_len, -1)

    @staticmethod
    def _line_shift_einsum(x:Tensor):
        "Shift the line i of `x` by p-i elements to the left, for scores laid out as (bs, n, p, n_heads)."
        # Same trick as the module-level `_line_shift`, with the head axis last instead of second.
        bs,n,p,nh = x.size()
        x_pad = torch.cat([x.new_zeros(bs,n,1,nh), x], dim=2)
        return x_pad.view(bs,p + 1,n,nh)[:,1:].view_as(x)

    def _attention_einsum(self, x:Tensor, r:Tensor=None, u:Tensor=None, v:Tensor=None, mask:Tensor=None, mem:Tensor=None):
        "Relative attention written with `einsum`: more readable, slightly slower than permute and matmul."
        # NOTE(review): this layout presumably needs `u`, `v` and `mask` shaped for (bs, i, j, n_heads) scores - confirm with callers.
        bs,x_len,seq_len = x.size(0),x.size(1),r.size(0)
        context = x if mem is None else torch.cat([mem, x], dim=1)
        wq,wk,wv = torch.chunk(self.attention(context), 3, dim=-1)
        wq = wq[:,-x_len:]
        wkr = self.r_attn(r)
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wkr = wkr.view(seq_len, self.n_heads, self.d_head)
        #### compute attention score (AC is (a) + (c) and BD is (b) + (d) in the paper)
        AC = torch.einsum('bind,bjnd->bijn', (wq+u, wk))
        # The shift helper for this layout is kept on the class: no module-level function of that name exists.
        BD = self._line_shift_einsum(torch.einsum('bind,jnd->bijn', (wq+v, wkr)))
        attn_score = AC + BD
        # Honor `scale`, as every other attention implementation in this hierarchy does.
        if self.scale: attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None:
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))
        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))
        return attn_vec.contiguous().view(bs, x_len, -1)
class DecoderLayer(Module):
    "One Transformer block: an attention module (`attn_cls`) followed by a feed-forward block."
    #A plain Sequential is not usable here because the attention takes more than one input.
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, resid_p:float=0., attn_p:float=0., ff_p:float=0.,
                 bias:bool=True, scale:bool=True, act:Activation=Activation.ReLU, double_drop:bool=True,
                 attn_cls:Callable=MultiHeadAttention):
        attn_kwargs = dict(resid_p=resid_p, attn_p=attn_p, bias=bias, scale=scale)
        self.mhra = attn_cls(n_heads, d_model, d_head, **attn_kwargs)
        self.ff = feed_forward(d_model, d_inner, ff_p=ff_p, act=act, double_drop=double_drop)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        "Apply attention (extra `kwargs` are forwarded to it), then the feed-forward block."
        attended = self.mhra(x, mask=mask, **kwargs)
        return self.ff(attended)
class Transformer(Module):
    "Transformer model: https://arxiv.org/abs/1706.03762."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=True, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadAttention,
                 learned_pos_enc:bool=True, mask:bool=True):
        self.mask = mask
        self.encoder = nn.Embedding(vocab_sz, d_model)
        # Positions are either looked up in a learned table or encoded with fixed sinusoids.
        if learned_pos_enc: self.pos_enc = nn.Embedding(ctx_len, d_model)
        else: self.pos_enc = PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        layer_kwargs = dict(resid_p=resid_p, attn_p=attn_p, ff_p=ff_p, bias=bias, scale=scale, act=act,
                            double_drop=double_drop, attn_cls=attn_cls)
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, **layer_kwargs)
                                     for k in range(n_layers)])

    def reset(self):
        "Nothing to reset: this model keeps no state between calls."
        pass

    def forward(self, x):
        "Embed the 2d batch of ids `x`, add positions, run every layer; return the result twice, each wrapped in a list."
        _, x_len = x.size()
        pos = torch.arange(0, x_len, device=x.device, dtype=x.dtype)
        inp = self.drop_emb(self.encoder(x) + self.pos_enc(pos)[None]) #.mul_(self.d_model ** 0.5)
        if self.mask:
            # Strict upper triangle set to 1 (the attention fills these scores with -inf), broadcast over batch and heads.
            mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).byte()[None,None]
        else:
            mask = None
        #[None,:,:None] for einsum implementation of attention
        for layer in self.layers: inp = layer(inp, mask=mask)
        return ([inp],[inp]) #For the LinearDecoder
class TransformerXL(Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., embed_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,
                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0):
        self.encoder = nn.Embedding(vocab_sz, d_model)
        if learned_pos_enc: self.pos_enc = nn.Embedding(ctx_len, d_model)
        else: self.pos_enc = PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        # `u` and `v` are created uninitialized here; `init_transformer` fills them when applied to this module.
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.mem_len = mem_len
        self.n_layers = n_layers
        self.d_model = d_model
        self.mask = mask
        self.init = False
        layer_kwargs = dict(resid_p=resid_p, attn_p=attn_p, ff_p=ff_p, bias=bias, scale=scale, act=act,
                            double_drop=double_drop, attn_cls=attn_cls)
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, **layer_kwargs)
                                     for k in range(n_layers)])

    def reset(self):
        "Reset the internal memory."
        # One empty tensor per layer plus one for the embeddings, created from the first parameter's data.
        ref = next(self.parameters()).data
        self.hidden = [ref.new(0) for _ in range(self.n_layers+1)]

    def _update_mems(self, hids):
        "Append `hids` to the stored memory and keep only the last `mem_len` steps of each entry."
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i, new in enumerate(hids):
                merged = torch.cat([self.hidden[i], new], dim=1)
                self.hidden[i] = merged[:,-self.mem_len:].detach()

    def select_hidden(self, idxs):
        "Index every stored memory tensor with `idxs`."
        self.hidden = [h[idxs] for h in self.hidden]

    def forward(self, x):
        "Run `x` through the model, attending over the stored memory when `mem_len > 0`, then update that memory."
        #The hidden state has to be initialized in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init:
            self.reset()
            self.init = True
        bs,x_len = x.size()
        inp = self.drop_emb(self.encoder(x)) #.mul_(self.d_model ** 0.5)
        # Length of the memory currently stored (0 when there is none or it is still empty).
        has_mem = hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1
        m_len = self.hidden[0].size(1) if has_mem else 0
        seq_len = m_len + x_len
        if self.mask:
            # The diagonal is offset by `m_len`, so no position in the memory is masked.
            mask = torch.triu(x.new_ones(x_len, seq_len), diagonal=1+m_len).byte()[None,None]
        else:
            mask = None
        #[None,:,:None] for einsum implementation of attention
        hids = []
        # Relative positions, from the farthest (`seq_len-1`) down to 0.
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0 : self._update_mems(hids)
        raw = self.hidden if self.mem_len > 0 else [core_out]
        return raw,[core_out]
def init_transformer(m):
    "Initialize `m` according to its class name: Linear, LayerNorm or TransformerXL (`u` and `v`); other modules are left alone."
    name = m.__class__.__name__
    if 'Linear' in name:
        if getattr(m, 'weight', None) is not None: nn.init.normal_(m.weight, 0., 0.02)
        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0.)
    elif 'LayerNorm' in name:
        # Same as above, except that the weights are drawn around 1.
        if getattr(m, 'weight', None) is not None: nn.init.normal_(m.weight, 1., 0.02)
        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0.)
    elif 'TransformerXL' in name:
        for attr in ('u', 'v'):
            if hasattr(m, attr): nn.init.normal_(getattr(m, attr), 0., 0.02)
# Default settings for a `Transformer` language model. Several keys (`output_p`, `tie_weights`, `out_bias`, `init`)
# are not parameters of `Transformer.__init__`; presumably they are consumed by the code that assembles the full model.
tfmer_lm_config = dict(ctx_len=512, n_layers=12, n_heads=12, d_model=768, d_head=64, d_inner=3072, resid_p=0.1, attn_p=0.1,
ff_p=0.1, embed_p=0.1, output_p=0., bias=True, scale=True, act=Activation.GeLU, double_drop=False,
tie_weights=True, out_bias=False, init=init_transformer, mask=True)
# Same settings for a classifier: no weight tying / output bias entries, and `mask=False`.
tfmer_clas_config = dict(ctx_len=512, n_layers=12, n_heads=12, d_model=768, d_head=64, d_inner=3072, resid_p=0.1, attn_p=0.1,
ff_p=0.1, embed_p=0.1, output_p=0., bias=True, scale=True, act=Activation.GeLU, double_drop=False,
init=init_transformer, mask=False)
def tfmer_lm_split(model:nn.Module) -> List[nn.Module]:
    "Split a Transformer language `model` in four groups: three thirds of the layers, then the embedding with `model[1]`."
    encoder = model[0]
    blocks = list(encoder.layers)
    third = len(blocks)//3
    # Any remainder of the division by 3 ends up in the last third.
    return [blocks[:third], blocks[third:2*third], blocks[2*third:], [encoder.encoder, model[1]]]
def tfmer_clas_split(model:nn.Module) -> List[nn.Module]:
    "Split a Transformer classifier `model` in five groups: the embedding, three thirds of the layers, then `model[1]`."
    # The encoder is reached through the `.module` attribute of `model[0]`.
    encoder = model[0].module
    blocks = list(encoder.layers)
    third = len(blocks)//3
    return [[encoder.encoder], blocks[:third], blocks[third:2*third], blocks[2*third:], [model[1]]]
# Default settings for a `TransformerXL` language model. Several keys (`output_p`, `tie_weights`, `out_bias`, `init`)
# are not parameters of `TransformerXL.__init__`; presumably they are consumed by the code that assembles the full model.
tfmerXL_lm_config = dict(ctx_len=150, n_layers=12, n_heads=10, d_model=410, d_head=41, d_inner=2100, resid_p=0.1, attn_p=0.1,
ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.ReLU, double_drop=True,
tie_weights=True, out_bias=True, init=init_transformer, mem_len=150, mask=True)
# Same settings for a classifier: no weight tying / output bias entries, and `mask=False`.
tfmerXL_clas_config = dict(ctx_len=150, n_layers=12, n_heads=10, d_model=410, d_head=41, d_inner=2100, resid_p=0.1, attn_p=0.1,
ff_p=0.1, embed_p=0.1, output_p=0.1, bias=False, scale=True, act=Activation.ReLU, double_drop=True,
init=init_transformer, mem_len=150, mask=False)
def tfmerXL_lm_split(model:nn.Module) -> List[nn.Module]:
    "Split a TransformerXL language `model` in four groups; `u` and `v` go with the first third of the layers."
    encoder = model[0]
    blocks = list(encoder.layers)
    third = len(blocks)//3
    # The lone parameters are wrapped in modules so they can sit in a group next to real layers.
    first = blocks[:third] + [ParameterModule(encoder.u), ParameterModule(encoder.v)]
    return [first, blocks[third:2*third], blocks[2*third:], [encoder.encoder, model[1]]]
def tfmerXL_clas_split(model:nn.Module) -> List[nn.Module]:
    "Split a TransformerXL classifier `model` in five groups; `u` and `v` go with the first third of the layers."
    # The encoder is reached through the `.module` attribute of `model[0]`.
    encoder = model[0].module
    blocks = list(encoder.layers)
    third = len(blocks)//3
    first = blocks[:third] + [ParameterModule(encoder.u), ParameterModule(encoder.v)]
    return [[encoder.encoder], first, blocks[third:2*third], blocks[2*third:], [model[1]]]
================================================
FILE: fastai/text/transform.py
================================================
"NLP data processing; tokenizes text and creates vocab indexes"
from ..torch_core import *
import spacy
from spacy.symbols import ORTH
# Names exported by a star-import of this module. NOTE(review): 'TK_REP' is listed twice (harmless).
__all__ = ['BaseTokenizer', 'SpacyTokenizer', 'Tokenizer', 'Vocab', 'fix_html', 'replace_all_caps', 'replace_rep', 'replace_wrep',
'rm_useless_spaces', 'spec_add_spaces', 'BOS', 'EOS', 'FLD', 'UNK', 'PAD', 'TK_MAJ', 'TK_UP', 'TK_REP', 'TK_REP', 'TK_WREP',
'deal_caps']
# Text of the special tokens.
BOS,EOS,FLD,UNK,PAD = 'xxbos','xxeos','xxfld','xxunk','xxpad'
# Text of the tokens inserted by the rules below (capitalization, all caps, character and word repetition).
TK_MAJ,TK_UP,TK_REP,TK_WREP = 'xxmaj','xxup','xxrep','xxwrep'
# `Vocab.create` puts these tokens, in this order, at the beginning of every vocabulary.
defaults.text_spec_tok = [UNK,PAD,BOS,EOS,FLD,TK_MAJ,TK_UP,TK_REP,TK_WREP]
class BaseTokenizer():
    "Minimal tokenizer interface; this base version splits on single spaces."
    def __init__(self, lang:str):
        self.lang = lang

    def tokenizer(self, t:str) -> List[str]:
        "Split `t` on each space character (consecutive spaces yield empty tokens)."
        return t.split(' ')

    def add_special_cases(self, toks:Collection[str]):
        "Hook for declaring tokens that must not be split; does nothing here."
        pass
class SpacyTokenizer(BaseTokenizer):
    "Wrapper around a spacy tokenizer to make it a `BaseTokenizer`."
    def __init__(self, lang:str):
        # Let the base class record `self.lang`, so the attribute exists on every `BaseTokenizer`.
        super().__init__(lang)
        self.tok = spacy.blank(lang, disable=["parser","tagger","ner"])

    def tokenizer(self, t:str) -> List[str]:
        "Return the text of each token spacy finds in `t`."
        return [token.text for token in self.tok.tokenizer(t)]

    def add_special_cases(self, toks:Collection[str]):
        "Register each string of `toks` as a special case that spacy keeps as a single token."
        for w in toks:
            self.tok.tokenizer.add_special_case(w, [{ORTH: w}])
def spec_add_spaces(t:str) -> str:
    "Put a space on each side of every `/`, `#` and newline character in `t`."
    special = re.compile(r'([/#\n])')
    return special.sub(r' \1 ', t)
def rm_useless_spaces(t:str) -> str:
    "Collapse every run of two or more spaces in `t` into a single space."
    runs = re.compile(' {2,}')
    return runs.sub(' ', t)
def replace_rep(t:str) -> str:
    "Replace a non-space character repeated four times or more in `t` by `TK_REP`, the count and the character."
    pattern = re.compile(r'(\S)(\1{3,})')
    def _expand(m:Collection[str]) -> str:
        # First group: the character; second group: all of its following repetitions.
        char, repeats = m.groups()
        return f' {TK_REP} {len(repeats)+1} {char} '
    return pattern.sub(_expand, t)
def replace_wrep(t:str) -> str:
    "Replace a word (with its trailing separator) repeated four times or more in `t` by `TK_WREP`, the count and the word."
    pattern = re.compile(r'(\b\w+\W+)(\1{3,})')
    def _expand(m:Collection[str]) -> str:
        # First group: the word and its separator; second group: all the following repetitions.
        word, repeats = m.groups()
        return f' {TK_WREP} {len(repeats.split())+1} {word} '
    return pattern.sub(_expand, t)
def fix_html(x:str) -> str:
    "Apply a list of replacements to `x` to clean up html leftovers, then unescape html entities and collapse spaces."
    re1 = re.compile(r' +')
    # The two tag-like literals `'<br />'` and `'<unk>'` are restored here: with the tags stripped, the first
    # literal was split across lines (a syntax error) and the second became `''`, which matches between every character.
    # The backslash replacement has to stay last since earlier replacements look for escaped sequences.
    x = x.replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace(
        'nbsp;', ' ').replace('#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace(
        '<br />', "\n").replace('\\"', '"').replace('<unk>',UNK).replace(' @.@ ','.').replace(
        ' @-@ ','-').replace(' @,@ ',',').replace('\\', ' \\ ')
    return re1.sub(' ', html.unescape(x))
def replace_all_caps(x:Collection[str]) -> Collection[str]:
    "Lowercase each token of `x` that is upper case and longer than one character, inserting `TK_UP` before it."
    out = []
    for tok in x:
        shouting = tok.isupper() and len(tok) > 1
        if shouting: out += [TK_UP, tok.lower()]
        else: out.append(tok)
    return out
def deal_caps(x:Collection[str]) -> Collection[str]:
    "Lowercase every token of `x` and drop empty ones; `TK_MAJ` is inserted before tokens that were Capitalized."
    out = []
    for tok in x:
        if tok == '': continue
        # Capitalized: upper-case first character followed by a lower-case remainder (at least two characters overall).
        capitalized = tok[0].isupper() and len(tok) > 1 and tok[1:].islower()
        if capitalized: out.append(TK_MAJ)
        out.append(tok.lower())
    return out
# Rules applied by `Tokenizer.process_text` to the raw string, in this order, before tokenization.
defaults.text_pre_rules = [fix_html, replace_rep, replace_wrep, spec_add_spaces, rm_useless_spaces]
# Rules applied by `Tokenizer.process_text` to the list of tokens, in this order, after tokenization.
defaults.text_post_rules = [replace_all_caps, deal_caps]
class Tokenizer():
    "Combine pre-rules, a tokenizer function and post-rules to tokenize texts, optionally over several processes."
    def __init__(self, tok_func:Callable=SpacyTokenizer, lang:str='en', pre_rules:ListRules=None,
                 post_rules:ListRules=None, special_cases:Collection[str]=None, n_cpus:int=None):
        self.tok_func = tok_func
        self.lang = lang
        self.pre_rules = ifnone(pre_rules, defaults.text_pre_rules )
        self.post_rules = ifnone(post_rules, defaults.text_post_rules)
        # An empty or missing `special_cases` falls back to the default special tokens.
        self.special_cases = special_cases or defaults.text_spec_tok
        self.n_cpus = ifnone(n_cpus, defaults.cpus)

    def __repr__(self) -> str:
        header = f'Tokenizer {self.tok_func.__name__} in {self.lang} with the following rules:\n'
        rules = [*self.pre_rules, *self.post_rules]
        return header + ''.join(f' - {rule.__name__}\n' for rule in rules)

    def process_text(self, t:str, tok:BaseTokenizer) -> List[str]:
        "Apply the pre-rules to `t`, tokenize it with `tok`, then apply the post-rules to the tokens."
        for rule in self.pre_rules: t = rule(t)
        toks = tok.tokenizer(t)
        for rule in self.post_rules: toks = rule(toks)
        return toks

    def _process_all_1(self, texts:Collection[str]) -> List[List[str]]:
        "Tokenize all of `texts` in the current process, with a freshly created tokenizer."
        tok = self.tok_func(self.lang)
        if self.special_cases: tok.add_special_cases(self.special_cases)
        return [self.process_text(str(t), tok) for t in texts]

    def process_all(self, texts:Collection[str]) -> List[List[str]]:
        "Tokenize all of `texts`, splitting the work over `n_cpus` processes when `n_cpus > 1`."
        if self.n_cpus <= 1: return self._process_all_1(texts)
        with ProcessPoolExecutor(self.n_cpus) as e:
            chunks = partition_by_cores(texts, self.n_cpus)
            # Each worker returns a list of token lists; they are concatenated in order.
            return sum(e.map(self._process_all_1, chunks), [])
class Vocab():
    "Contain the correspondence between numbers and tokens and numericalize."
    def __init__(self, itos:Collection[str]):
        self.itos = itos
        # Unknown tokens map to 0 through the `defaultdict` (a lookup also stores the missing key).
        self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})

    def numericalize(self, t:Collection[str]) -> List[int]:
        "Convert a list of tokens `t` to their ids."
        return [self.stoi[w] for w in t]

    def textify(self, nums:Collection[int], sep=' ') -> List[str]:
        "Convert a list of `nums` to their tokens, joined with `sep` into one string unless `sep` is None."
        return sep.join([self.itos[i] for i in nums]) if sep is not None else [self.itos[i] for i in nums]

    def __getstate__(self):
        # Only `itos` is pickled; `stoi` is rebuilt from it in `__setstate__`.
        return {'itos':self.itos}

    def __setstate__(self, state:dict):
        self.itos = state['itos']
        self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})

    def save(self, path):
        "Save `self.itos` in `path`"
        # The context manager closes (and flushes) the file even if pickling fails.
        with open(path, 'wb') as f: pickle.dump(self.itos, f)

    @classmethod
    def create(cls, tokens:Tokens, max_vocab:int, min_freq:int) -> 'Vocab':
        "Create a vocabulary from a set of `tokens`, keeping at most `max_vocab` entries seen at least `min_freq` times."
        freq = Counter(p for o in tokens for p in o)
        itos = [o for o,c in freq.most_common(max_vocab) if c >= min_freq]
        # Special tokens always come first, in the order of `defaults.text_spec_tok`.
        for o in reversed(defaults.text_spec_tok):
            if o in itos: itos.remove(o)
            itos.insert(0, o)
        itos = itos[:max_vocab]
        if len(itos) < max_vocab: #Make sure vocab size is a multiple of 8 for fast mixed precision training
            while len(itos)%8 !=0: itos.append('xxfake')
        return cls(itos)

    @classmethod
    def load(cls, path):
        "Load the `Vocab` contained in `path`"
        # NOTE: unpickling can execute arbitrary code; only load vocabulary files from a trusted source.
        with open(path, 'rb') as f: itos = pickle.load(f)
        return cls(itos)
================================================
FILE: fastai/torch_core.py
================================================
"Utility functions to help deal with tensors"
from .imports.torch import *
from .core import *
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
# Type aliases used in the annotations of the library. They only name existing types (`Rank0Tensor` is a `NewType`
# over `Tensor`); the names not defined here (`ItemBase`, `KWArgs`, `ArgStar`, ...) come from the star-imports above.
AffineMatrix = Tensor
BoolOrTensor = Union[bool,Tensor]
FloatOrTensor = Union[float,Tensor]
IntOrTensor = Union[int,Tensor]
ItemsList = Collection[Union[Tensor,ItemBase,'ItemsList',float,int]]
LambdaFunc = Callable[[Tensor],Tensor]
LayerFunc = Callable[[nn.Module],None]
ModuleList = Collection[nn.Module]
NPArray = np.ndarray
OptOptimizer = Optional[optim.Optimizer]
ParamList = Collection[nn.Parameter]
Rank0Tensor = NewType('OneEltTensor', Tensor)
SplitFunc = Callable[[nn.Module], List[nn.Module]]
SplitFuncOrIdxList = Union[Callable, Collection[ModuleList]]
TensorOrNumber = Union[Tensor,Number]
TensorOrNumList = Collection[TensorOrNumber]
TensorImage = Tensor
TensorImageSize = Tuple[int,int,int]
Tensors = Union[Tensor, Collection['Tensors']]
Weights = Dict[str,Tensor]
# Aliases built on top of the ones above.
AffineFunc = Callable[[KWArgs], AffineMatrix]
HookFunc = Callable[[nn.Module, Tensors, Tensors], Any]
LogitTensorImage = TensorImage
LossFunction = Callable[[Tensor, Tensor], Rank0Tensor]
MetricFunc = Callable[[Tensor,Tensor],TensorOrNumber]
MetricFuncList = Collection[MetricFunc]
MetricsList = Collection[TensorOrNumber]
OptLossFunc = Optional[LossFunction]
OptMetrics = Optional[MetricsList]
OptSplitFunc = Optional[SplitFunc]
PixelFunc = Callable[[TensorImage, ArgStar, KWArgs], TensorImage]
LightingFunc = Callable[[LogitTensorImage, ArgStar, KWArgs], LogitTensorImage]
# Maps each type alias object to its name as a string.
# NOTE(review): aliases that are the same object (e.g. `AffineMatrix` and `TensorImage` are both `Tensor`) share one
# dict key, so only the name listed last survives for them.
fastai_types = {
AnnealFunc:'AnnealFunc', ArgStar:'ArgStar', BatchSamples:'BatchSamples',
FilePathList:'FilePathList', Floats:'Floats', ImgLabel:'ImgLabel', ImgLabels:'ImgLabels', KeyFunc:'KeyFunc',
KWArgs:'KWArgs', ListOrItem:'ListOrItem', ListRules:'ListRules', ListSizes:'ListSizes',
NPArrayableList:'NPArrayableList', NPArrayList:'NPArrayList', NPArrayMask:'NPArrayMask', NPImage:'NPImage',
OptDataFrame:'OptDataFrame', OptListOrItem:'OptListOrItem', OptRange:'OptRange', OptStrTuple:'OptStrTuple',
OptStats:'OptStats', PathOrStr:'PathOrStr', PBar:'PBar', Point:'Point', Points:'Points', Sizes:'Sizes',
SplitArrayList:'SplitArrayList', StartOptEnd:'StartOptEnd', StrList:'StrList', Tokens:'Tokens',
OptStrList:'OptStrList', AffineMatrix:'AffineMatrix', BoolOrTensor:'BoolOrTensor', FloatOrTensor:'FloatOrTensor',
IntOrTensor:'IntOrTensor', ItemsList:'ItemsList', LambdaFunc:'LambdaFunc',
LayerFunc:'LayerFunc', ModuleList:'ModuleList', OptOptimizer:'OptOptimizer', ParamList:'ParamList',
Rank0Tensor:'Rank0Tensor', SplitFunc:'SplitFunc', SplitFuncOrIdxList:'SplitFuncOrIdxList',
TensorOrNumber:'TensorOrNumber', TensorOrNumList:'TensorOrNumList', TensorImage:'TensorImage',
TensorImageSize:'TensorImageSize', Tensors:'Tensors', Weights:'Weights', AffineFunc:'AffineFunc',
HookFunc:'HookFunc', LogitTensorImage:'LogitTensorImage', LossFunction:'LossFunction', MetricFunc:'MetricFunc',
MetricFuncList:'MetricFuncList', MetricsList:'MetricsList', OptLossFunc:'OptLossFunc', OptMetrics:'OptMetrics',
OptSplitFunc:'OptSplitFunc', PixelFunc:'PixelFunc', LightingFunc:'LightingFunc', IntsOrStrs:'IntsOrStrs',
PathLikeOrBinaryStream:'PathLikeOrBinaryStream'
}
# Batch norm layer classes (also the base of `no_wd_types`).
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
# Linear, convolution and transposed convolution layer classes (presumably "layers that can carry a bias").
bias_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
def is_pool_type(l:Callable):
    "Return a regex match (truthy) if the class name of `l` ends in `Pool1d`, `Pool2d` or `Pool3d`, else None."
    cls_name = l.__class__.__name__
    return re.search(r'Pool[123]d$', cls_name)
# Batch norm classes plus `nn.LayerNorm` (by its name, the layers excluded from weight decay - confirm at the use site).
no_wd_types = bn_types + (nn.LayerNorm,)
# Default device: the current CUDA device when CUDA is available, the CPU otherwise.
defaults.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# `optim.Adam` with `betas` preset to (0.9, 0.99); despite the name, this is a partial of plain `Adam`.
AdamW = partial(optim.Adam, betas=(0.9,0.99))
#Monkey-patch `torch.cuda.set_device` so that it updates `defaults.device`
# Keep a reference to the original function: the replacement still calls it.
_old_torch_cuda_set_device = torch.cuda.set_device
def _new_torch_cuda_set_device(device):
# Select the device as usual first...
_old_torch_cuda_set_device(device)
# ...then mirror the choice in `defaults.device`; an int index is wrapped in a `torch.device`, anything else is stored as is.
defaults.device = torch.device('cuda', device) if isinstance(device, int) else device
# From here on every call to `torch.cuda.set_device` goes through the wrapper.
torch.cuda.set_device = _new_torch_cuda_set_device
def tensor(x:Any, *rest)->Tensor:
    "Like `torch.as_tensor`, but handle lists too, and can pass multiple vector elements directly."
    # Several positional arguments are treated as the elements of one vector.
    if rest: x = (x,) + rest
    listy = is_listy(x)
    # XXX: Pytorch bug in dataloader using num_workers>0; TODO: create repro and report
    if listy and len(x) == 0: return tensor(0)
    res = torch.tensor(x) if listy else as_tensor(x)
    if res.dtype is not torch.int32: return res
    warn('Tensor is int32: upgrading to int64; for better performance use int64 input')
    return res.long()
class Module(nn.Module, metaclass=PrePostInitMeta):
    "Same as `nn.Module`, but no need for subclasses to call `super().__init__`"
    def __pre_init__(self):
        # Performs the `nn.Module` initialization (presumably invoked by the metaclass ahead of `__init__`).
        super().__init__()

    def __init__(self):
        pass
def np_address(x:np.ndarray)->int:
    "Memory address of the first element of `x`, read from its array interface."
    pointer, *_ = x.__array_interface__['data']
    return pointer
def to_detach(b:Tensors, cpu:bool=True):
    "Recursively detach the tensors of `b`, moving them to the CPU as well if `cpu=True`; other objects are kept as is."
    def _detach(x, cpu=True):
        if isinstance(x,Tensor):
            detached = x.detach()
            return detached.cpu() if cpu else detached
        return x
    return recurse(_detach, b, cpu=cpu)
def to_data(b:ItemsList):
    "Recursively replace every `ItemBase` of `b` by its `data` attribute; other objects are kept as is."
    def _unwrap(x): return x.data if isinstance(x,ItemBase) else x
    return recurse(_unwrap, b)
def to_cpu(b:ItemsList):
    "Recursively move the tensors of `b` to the CPU; other objects are kept as is."
    def _move(x): return x.cpu() if isinstance(x,Tensor) else x
    return recurse(_move, b)
def to_half(b:Collection[Tensor])->Collection[Tensor]:
    "Recursively convert the tensors of `b` to FP16, leaving int64, int32 and int16 tensors untouched."
    kept_dtypes = [torch.int64, torch.int32, torch.int16]
    def _convert(x): return x if x.dtype in kept_dtypes else x.half()
    return recurse(_convert, b)
def to_float(b:Collection[Tensor])->Collection[Tensor]:
    "Recursively map lists of tensors in `b ` to FP32 (int64, int32 and int16 tensors are left untouched)."
    # `.float()` converts to 32-bit floats: this is the counterpart of `to_half`, not another FP16 conversion.
    return recurse(lambda x: x.float() if x.dtype not in [torch.int64, torch.int32, torch.int16] else x, b)
def to_device(b:Tensors, device:torch.device):
    "Recursively put `b` on `device` (or on `defaults.device` when `device` is None), with non-blocking copies."
    target = ifnone(device, defaults.device)
    def _send(x): return x.to(target, non_blocking=True)
    return recurse(_send, b)
def data_collate(batch:ItemsList)->Tensor:
    "Unwrap the items of `batch` with `to_data`, then collate them with PyTorch's default collate function."
    raw = to_data(batch)
    return torch.utils.data.dataloader.default_collate(raw)
def requires_grad(m:nn.Module, b:Optional[bool]=None)->Optional[bool]:
    "Getter (when `b` is None, for the first param of `m`) or setter (on all params of `m`) for `requires_grad`."
    params = list(m.parameters())
    # A module without parameters has nothing to report or to change.
    if len(params) == 0: return None
    if b is None: return params[0].requires_grad
    for param in params: param.requires_grad = b
def trainable_params(m:nn.Module)->ParamList:
    "Return a lazy `filter` over the params of `m` that have `requires_grad` set."
    def _is_trainable(p):
        return p.requires_grad
    return filter(_is_trainable, m.parameters())
def children(m:nn.Module)->ModuleList:
    "List of the direct children of `m`."
    return [*m.children()]
def num_children(m:nn.Module)->int:
    "Number of direct children modules of `m`."
    direct_children = children(m)
    return len(direct_children)
def range_children(m:nn.Module)->Iterator[int]:
    "Return `range(n)` where `n` is the number of children of `m`."
    n_children = num_children(m)
    return range(n_children)
class ParameterModule(Module):
    "Module holding a lone parameter `p` (stored as `val`); its forward pass is the identity."
    def __init__(self, p:nn.Parameter):
        self.val = p

    def forward(self, x):
        return x
def children_and_parameters(m:nn.Module):
    "Return the children of `m`, followed by a `ParameterModule` for each param of `m` not owned by a child."
    result = list(m.children())
    # ids of every parameter reachable through one of the children
    owned_by_children = {id(p) for child in m.children() for p in child.parameters()}
    result.extend(ParameterModule(p) for p in m.parameters() if id(p) not in owned_by_children)
    return result
def flatten_model(m:nn.Module):
    "Recursively flatten `m` into a list of the modules that have no children."
    # A module without children is a leaf and forms a list on its own.
    if not num_children(m): return [m]
    leaves = []
    for sub in children_and_parameters(m):
        leaves += flatten_model(sub)
    return leaves
#flatten_model = lambda m: sum(map(flatten_model,children_and_parameters(m)),[]) if num_children(m) else [m]
def first_layer(m:nn.Module)->nn.Module:
    "First element of the flattened version of module `m`."
    flat = flatten_model(m)
    return flat[0]
def last_layer(m:nn.Module)->nn.Module:
    "Last element of the flattened version of module `m`."
    flat = flatten_model(m)
    return flat[-1]
def split_model_idx(model:nn.Module, idxs:Collection[int])->ModuleList:
    """Split `model` according to the indexes in `idxs`.

    The model is flattened, then cut at each index of `idxs`; a leading 0 and a trailing
    `len(layers)` are added when missing. Returns one `nn.Sequential` per consecutive pair
    of boundaries. `idxs` is never modified, and may be any collection (list, tuple, ...);
    an empty `idxs` yields a single group holding every layer.
    """
    layers = flatten_model(model)
    # Work on a copy: appending the final boundary must not leak into the caller's list.
    bounds = list(idxs)
    if not bounds or bounds[0] != 0: bounds = [0] + bounds
    if bounds[-1] != len(layers): bounds.append(len(layers))
    return [nn.Sequential(*layers[i:j]) for i,j in zip(bounds[:-1],bounds[1:])]
def split_model(model:nn.Module=None, splits:Collection[Union[nn.Module,ModuleList]]=None):
    "Split `model` according to the layers in `splits`."
    splits = listify(splits)
    # Lists of modules are wrapped directly, one `nn.Sequential` per entry.
    if not isinstance(splits[0], nn.Module):
        return [nn.Sequential(*group) for group in splits]
    # Modules are turned into positions of their first layer inside the flattened model.
    flat = flatten_model(model)
    cut_points = [flat.index(first_layer(s)) for s in splits]
    return split_model_idx(model, cut_points)
def get_param_groups(layer_groups:Collection[nn.Module])->List[List[nn.Parameter]]:
    "For each layer group, list the trainable params of its direct children."
    groups = []
    for group in layer_groups:
        params = []
        for child in group.children():
            params.extend(trainable_params(child))
        groups.append(params)
    return groups
def split_no_wd_params(layer_groups:Collection[nn.Module])->List[List[nn.Parameter]]:
    "Separate the parameters in `layer_groups` between `no_wd_types` and bias (`bias_types`) from the rest."
    result = []
    for group in layer_groups:
        regular, special = [], []
        for child in group.children():
            if isinstance(child, no_wd_types):
                special.extend(trainable_params(child))
            elif isinstance(child, bias_types):
                bias = getattr(child, 'bias', None)
                # everything but the bias goes with the regular params
                regular.extend(p for p in trainable_params(child) if p is not bias)
                if bias is not None: special.append(bias)
            else:
                regular.extend(trainable_params(child))
        #Since we scan the children separately, we might get duplicates (tied weights). We need to preserve the order
        #for the optimizer load of state_dict
        regular, special = uniqueify(regular), uniqueify(special)
        result.append(regular)
        result.append(special)
    return result
def set_bn_eval(m:nn.Module)->None:
    "Recursively put in eval mode every `bn_types` descendant of `m` whose first param does not require grad."
    for child in m.children():
        if isinstance(child, bn_types):
            first_param = next(child.parameters())
            if not first_param.requires_grad: child.eval()
        set_bn_eval(child)
def batch_to_half(b:Collection[Tensor])->Collection[Tensor]:
    "Convert the first element of batch `b` with `to_half`; the second element is returned unchanged."
    half_input = to_half(b[0])
    return [half_input, b[1]]
def bn2float(module:nn.Module)->nn.Module:
    "Recursively convert the batchnorm layers of `module` back to float; returns `module`."
    batchnorm_base = torch.nn.modules.batchnorm._BatchNorm
    if isinstance(module, batchnorm_base):
        module.float()
    for sub_module in module.children():
        bn2float(sub_module)
    return module
def model2half(model:nn.Module)->nn.Module:
    "Convert `model` to half precision, then restore its batchnorm layers to float via `bn2float`."
    halved = model.half()
    return bn2float(halved)
def init_default(m:nn.Module, func:LayerFunc=nn.init.kaiming_normal_)->nn.Module:
    "Initialize `m` weights with `func` and set `bias` to 0; nothing is done when `func` is falsy."
    if not func: return m
    if hasattr(m, 'weight'):
        func(m.weight)
    # Only fill the bias when it is an actual tensor (it must expose `data`).
    if hasattr(m, 'bias') and hasattr(m.bias, 'data'):
        m.bias.data.fill_(0.)
    return m
def cond_init(m:nn.Module, init_func:LayerFunc):
    "Apply `init_default` with `init_func` to `m`, unless `m` is a `bn_types` layer or does not require grad."
    if isinstance(m, bn_types): return
    if requires_grad(m):
        init_default(m, init_func)
def apply_leaf(m:nn.Module, f:LayerFunc):
    "Apply `f` to `m`, then recursively to each of its children."
    # The children are collected before `f` runs on `m`.
    sub_modules = children(m)
    if isinstance(m, nn.Module):
        f(m)
    for sub_module in sub_modules:
        apply_leaf(sub_module, f)
def apply_init(m, init_func:LayerFunc):
    "Run `cond_init` with `init_func` on `m` and all of its descendants."
    initializer = partial(cond_init, init_func=init_func)
    apply_leaf(m, initializer)
def in_channels(m:nn.Module) -> List[int]:
    "Return dimension 1 of the weight shape of the first layer of `m` that has a `weight`."
    weighted_layers = (layer for layer in flatten_model(m) if hasattr(layer, 'weight'))
    first_weighted = next(weighted_layers, None)
    if first_weighted is None:
        raise Exception('No weight layer')
    return first_weighted.weight.shape[1]
class ModelOnCPU():
    "Context manager that moves `model` to the CPU on entry and back to its previous device on exit."
    def __init__(self, model:nn.Module):
        self.model = model

    def __enter__(self):
        # Remember where the model lived (device of its first parameter).
        first_param = one_param(self.model)
        self.device = first_param.device
        return self.model.cpu()

    def __exit__(self, type, value, traceback):
        restored = self.model.to(self.device)
        self.model = restored
class NoneReduceOnCPU():
    "A context manager to evaluate `loss_func` with none reduce and weights on the CPU inside."
    def __init__(self, loss_func:LossFunction):
        self.loss_func = loss_func
        # Filled in by `__enter__` only when there is something to restore on exit.
        self.device = None
        self.old_red = None

    def __enter__(self):
        loss = self.loss_func
        if getattr(loss, 'weight', None) is not None:
            self.device = loss.weight.device
            loss.weight = loss.weight.cpu()
        if not hasattr(loss, 'reduction'):
            # No `reduction` attribute: pass it as a keyword argument instead.
            return partial(loss, reduction='none')
        self.old_red = loss.reduction
        loss.reduction = 'none'
        return loss

    def __exit__(self, type, value, traceback):
        loss = self.loss_func
        if self.device is not None:
            loss.weight = loss.weight.to(self.device)
        if self.old_red is not None:
            loss.reduction = self.old_red
def model_type(dtype):
    "Torch type for numpy `dtype`: float32 for floating types, int64 for integer types, `None` otherwise."
    if np.issubdtype(dtype, np.floating): return torch.float32
    if np.issubdtype(dtype, np.integer): return torch.int64
    return None
def np2model_tensor(a):
    "Transform numpy array `a` to a tensor, cast to the type `model_type` gives for its dtype (if any)."
    target_type = model_type(a.dtype)
    converted = as_tensor(a)
    return converted.type(target_type) if target_type else converted
def _pca(x, k=2):
    "Compute PCA of `x` with `k` dimensions."
    centered = x - torch.mean(x, 0)
    # Only the left singular vectors of the transposed data are used for the projection.
    left_vectors, _singular_values, _right_vectors = torch.svd(centered.t())
    return torch.mm(centered, left_vectors[:, :k])
torch.Tensor.pca = _pca
def trange_of(x):
    "Tensor `[0, 1, ..., len(x)-1]`."
    n_items = len(x)
    return torch.arange(n_items)
def to_np(x):
    "Convert tensor `x` to a numpy array, going through its `data` on the CPU."
    cpu_data = x.data.cpu()
    return cpu_data.numpy()
# monkey patching to allow matplotlib to plot tensors
def tensor__array__(self, dtype=None):
    "Numpy view of the tensor via `to_np`, cast to `dtype` when one is given."
    as_array = to_np(self)
    return as_array if dtype is None else as_array.astype(dtype, copy=False)
Tensor.__array__ = tensor__array__
Tensor.ndim = property(lambda x: len(x.shape))
def grab_idx(x,i,batch_first:bool=True):
    "Grab the `i`-th batch in `x`, `batch_first` stating the batch dimension."
    if batch_first:
        def _pick(o): return o[i].cpu()
    else:
        def _pick(o): return o[:,i].cpu()
    # Listy inputs are indexed element by element.
    return [_pick(o) for o in x] if is_listy(x) else _pick(x)
def logit(x:Tensor)->Tensor:
    "Logit of `x`, with `x` clamped to [1e-7, 1-1e-7] to avoid inf."
    clamped = x.clamp(1e-7, 1-1e-7)
    inverse_odds = 1/clamped - 1
    return -inverse_odds.log()
def logit_(x:Tensor)->Tensor:
    "Inplace logit of `x`, clamped to avoid inf"
    x.clamp_(1e-7, 1-1e-7)
    # Same chain of in-place ops as -(1/x - 1).log(), one step per line.
    inverse_odds = x.reciprocal_().sub_(1)
    log_inverse_odds = inverse_odds.log_()
    return log_inverse_odds.neg_()
def set_all_seed(seed:int)->None:
    "Seed the numpy, torch and python `random` pseudo random generators with `seed`."
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)
def uniform(low:Number, high:Number=None, size:Optional[List[int]]=None)->FloatOrTensor:
    "Draw 1 or shape=`size` random floats from uniform dist: min=`low`, max=`high`."
    if high is None: high = low
    # Without `size` a single python float is drawn.
    if size is None:
        return random.uniform(low, high)
    shape = listify(size)
    return torch.FloatTensor(*shape).uniform_(low, high)
def log_uniform(low, high, size:Optional[List[int]]=None)->FloatOrTensor:
    "Draw 1 or shape=`size` random floats whose log is uniform between log(`low`) and log(`high`)."
    drawn = uniform(log(low), log(high), size)
    if size is None:
        return exp(drawn)
    return drawn.exp_()
def rand_bool(p:float, size:Optional[List[int]]=None)->BoolOrTensor:
    "Draw 1 or shape=`size` random booleans (`True` occuring with probability `p`)."
    # A uniform draw in [0,1) falls below `p` with probability `p`.
    return uniform(0,1,size)<p

def rand_int(low:int, high:int, size:Optional[List[int]]=None)->IntOrTensor:
    "Generate int or tensor `size` of ints between `low` and `high` (included)."
    # `torch.randint` excludes its upper bound, hence `high+1` to keep `high` included.
    return random.randint(low,high) if size is None else torch.randint(low,high+1,size)
def one_param(m: nn.Module)->Tensor:
    "First parameter yielded by `m.parameters()`."
    param_iter = m.parameters()
    return next(param_iter)
def try_int(o:Any)->Any:
    """Try to convert `o` to int, default to `o` if not possible.

    Arrays and tensors are only converted when they are 0-dimensional; sized objects
    (including strings) and objects exposing `__array_interface__` are returned as is.
    """
    # NB: single-item rank-1 array/tensor can be converted to int, but we don't want to do this
    if isinstance(o, (np.ndarray,Tensor)): return o if o.ndim else int(o)
    if isinstance(o, collections.abc.Sized) or getattr(o,'__array_interface__',False): return o
    try: return int(o)
    # Any conversion failure falls back to `o`, but KeyboardInterrupt/SystemExit must still propagate.
    except Exception: return o
def get_model(model:nn.Module):
    "Return `model.module` when `model` is a (Distributed)DataParallel wrapper, else `model` itself."
    wrapper_types = (DistributedDataParallel, nn.DataParallel)
    if isinstance(model, wrapper_types):
        return model.module
    return model
def flatten_check(out:Tensor, targ:Tensor) -> Tensor:
    "Flatten `out` and `targ` to rank 1 and check they hold the same number of elements; returns both."
    flat_out = out.contiguous().view(-1)
    flat_targ = targ.contiguous().view(-1)
    n_out, n_targ = len(flat_out), len(flat_targ)
    assert n_out == n_targ, f"Expected output and target to have the same number of elements but got {n_out} and {n_targ}."
    return flat_out,flat_targ
#Monkey-patch nn.DataParallel.reset
def _data_parallel_reset(self):
    "Forward `reset` to the wrapped module when it defines one."
    wrapped = self.module
    if hasattr(wrapped, 'reset'):
        wrapped.reset()
nn.DataParallel.reset = _data_parallel_reset
def remove_module_load(state_dict):
    """Create a new `OrderedDict` whose keys no longer start with `module.`.

    Keys that carry the `module.` prefix have it stripped; any other key is copied
    unchanged instead of losing its first seven characters. Order is preserved and
    `state_dict` itself is not modified.
    """
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[new_key] = v
    return new_state_dict
def num_distrib():
    "Number of processes in distributed training: `WORLD_SIZE` env var as an int, 0 when unset."
    world_size = os.environ.get('WORLD_SIZE')
    return 0 if world_size is None else int(world_size)
def rank_distrib():
    "Distributed rank of this process: `RANK` env var as an int, 0 when unset."
    rank = os.environ.get('RANK')
    return 0 if rank is None else int(rank)
def add_metrics(last_metrics:Collection[Rank0Tensor], mets:Union[Rank0Tensor, Collection[Rank0Tensor]]):
    "Return a dictionary for updating `last_metrics` with `mets`."
    previous, extra = listify(last_metrics), listify(mets)
    combined = previous + extra
    return {'last_metrics': combined}
def try_save(state:Dict, path:Path=None, file:PathLikeOrBinaryStream=None):
    """Save `state` with `torch.save`, either to `path/file` or to the open stream `file`.

    When `file` is path-like, `path/file` is opened for binary writing and always closed
    again, whether or not saving succeeds. A stream passed by the caller is left open.
    An `OSError` raised while saving is re-raised as an `Exception` naming the destination.
    """
    opened_here = is_pathlike(file)
    target = open(path/file, 'wb') if opened_here else file
    try: torch.save(state, target)
    except OSError as e:
        # `path/file` is only meaningful for path-like `file`; a stream is reported as is.
        dest = path/file if opened_here else file
        raise Exception(f"{e}\n Can't write {dest}. Pass an absolute writable pathlib obj `fname`.") from e
    finally:
        # Only close what this function opened.
        if opened_here: target.close()
def np_func(f):
    "Convert a function taking and returning numpy arrays to one taking and returning tensors"
    @functools.wraps(f)
    def _inner(*args, **kwargs):
        # Tensor positional arguments go through `to_np`; everything else is passed through.
        np_args = []
        for arg in args:
            np_args.append(to_np(arg) if isinstance(arg, Tensor) else arg)
        return tensor(f(*np_args, **kwargs))
    return _inner
================================================
FILE: fastai/train.py
================================================
"Provides advanced training extensions to `fastai.basic_train`. Includes half-precision, learning rate finder, mixup, and one-cycle"
from .torch_core import *
from .callbacks import *
from .basic_data import *
from .basic_train import *
__all__ = ['BnFreeze', 'GradientClipping', 'ShowGraph', 'Interpretation', 'ClassificationInterpretation', 'MultiLabelClassificationInterpretation',
'fit_one_cycle', 'lr_find', 'one_cycle_scheduler', 'to_fp16', 'to_fp32', 'mixup', 'AccumulateScheduler']
def one_cycle_scheduler(lr_max:float, **kwargs:Any)->OneCycleScheduler:
    "Return a partial of `OneCycleScheduler` with `lr_max` and `kwargs` already bound."
    scheduler_factory = partial(OneCycleScheduler, lr_max=lr_max, **kwargs)
    return scheduler_factory
def fit_one_cycle(learn:Learner, cyc_len:int, max_lr:Union[Floats,slice]=defaults.lr,
                  moms:Tuple[float,float]=(0.95,0.85), div_factor:float=25., pct_start:float=0.3, final_div:float=None,
                  wd:float=None, callbacks:Optional[CallbackList]=None, tot_epochs:int=None, start_epoch:int=None,
                  batch_multiplier:int=1)->None:
    "Fit a model following the 1cycle policy."
    max_lr = learn.lr_range(max_lr)
    cbs = listify(callbacks)
    # The 1cycle schedule is driven by a `OneCycleScheduler` added after the caller's callbacks.
    scheduler = OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
                                  final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch)
    cbs.append(scheduler)
    learn.fit(cyc_len, max_lr, wd=wd, callbacks=cbs, batch_multiplier=batch_multiplier)
def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None,
            batch_multiplier:int=1):
    "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges."
    def _prepare(lr):
        # Expand through `learn.lr_range`, then turn listy results into numpy arrays.
        lr = learn.lr_range(lr)
        return np.array(lr) if is_listy(lr) else lr
    start_lr = _prepare(start_lr)
    end_lr = _prepare(end_lr)
    finder = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
    # Enough epochs to cover `num_it` iterations of the training dataloader.
    n_epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
    learn.fit(n_epochs, start_lr, callbacks=[finder], wd=wd, batch_multiplier=batch_multiplier)
def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=True, clip:float=None,
            flat_master:bool=False, max_scale:float=2**24)->Learner:
    "Put `learn` in FP16 precision mode."
    # Start from a clean FP32 state before switching to half precision.
    learn.to_fp32()
    learn.model = model2half(learn.model)
    learn.data.add_tfm(batch_to_half)
    mp_callback = MixedPrecision(learn, loss_scale=loss_scale, max_noskip=max_noskip, dynamic=dynamic, clip=clip,
                                 flat_master=flat_master, max_scale=max_scale)
    learn.mp_cb = mp_callback
    learn.callbacks.append(learn.mp_cb)
    return learn
def to_fp32(learn:Learner):
    """Put `learn` back to FP32 precision mode.

    Removes the `batch_to_half` transform from the data, drops every `MixedPrecision`
    callback from `learn.callbacks` and converts the model to float. Returns `learn`.
    """
    learn.data.remove_tfm(batch_to_half)
    # Rebuild the list in place: calling `remove` while iterating over the same list skips
    # the element that follows each removal, which could leave a `MixedPrecision` behind.
    learn.callbacks[:] = [cb for cb in learn.callbacks if not isinstance(cb, MixedPrecision)]
    learn.model = learn.model.float()
    return learn
def mixup(learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True) -> Learner:
    "Add mixup https://arxiv.org/abs/1710.09412 to `learn`."
    mixup_factory = partial(MixUpCallback, alpha=alpha, stack_x=stack_x, stack_y=stack_y)
    learn.callback_fns.append(mixup_factory)
    return learn
# Attach the training helpers defined in this module to `Learner`, so that each can be
# called as a method (`learn` is their first parameter).
Learner.fit_one_cycle = fit_one_cycle
Learner.lr_find = lr_find
Learner.to_fp16 = to_fp16
Learner.to_fp32 = to_fp32
Learner.mixup = mixup
class ShowGraph(LearnerCallback):
    "Update a graph of learner stats and metrics after each epoch."
    def on_epoch_end(self, n_epochs:int, last_metrics:MetricsList, **kwargs)->bool:
        "If we have `last_metrics` plot them in our pbar graph"
        # Nothing to plot until a first metric value is available.
        if last_metrics is None or last_metrics[0] is None: return {}
        recorder = self.learn.recorder
        train_iters = range_of(recorder.losses)
        valid_iters = np.array(recorder.nb_batches).cumsum()
        # x-axis: iterations done so far, plus the remaining epochs sized like the last one.
        epochs_left = n_epochs - len(recorder.nb_batches)
        x_bounds = (0, epochs_left * recorder.nb_batches[-1] + len(recorder.losses))
        top_train, top_valid = max(Tensor(recorder.losses)), max(Tensor(recorder.val_losses))
        y_bounds = (0, max((top_train, top_valid)))
        curves = [(train_iters, recorder.losses), (valid_iters, recorder.val_losses)]
        recorder.pbar.update_graph(curves, x_bounds, y_bounds)
        return {}
class BnFreeze(LearnerCallback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    def on_epoch_begin(self, **kwargs:Any)->None:
        "Put bn layers in eval mode just after `model.train()`."
        model = self.learn.model
        set_bn_eval(model)
class GradientClipping(LearnerCallback):
    "Gradient clipping during training."
    def __init__(self, learn:Learner, clip:float = 0.):
        super().__init__(learn)
        self.clip = clip

    def on_backward_end(self, **kwargs):
        "Clip the gradient before the optimizer step."
        # A falsy `clip` (0 by default) disables clipping.
        if not self.clip: return
        nn.utils.clip_grad_norm_(self.learn.model.parameters(), self.clip)
def clip_grad(learn:Learner, clip:float=0.1)->Learner:
    "Add gradient clipping of `clip` during training."
    clipping_factory = partial(GradientClipping, clip=clip)
    learn.callback_fns.append(clipping_factory)
    return learn
Learner.clip_grad = clip_grad
class AccumulateScheduler(LearnerCallback):
    "Does an accumulated step every `n_step` batches by accumulating gradients"
    def __init__(self, learn:Learner, n_step:int = 1, drop_last:bool = False):
        super().__init__(learn)
        self.n_step,self.drop_last = n_step,drop_last

    def on_train_begin(self, **kwargs):
        "Warn when the loss function has a `reduction` other than 'sum'."
        if hasattr(self.loss_func, "reduction") and (self.loss_func.reduction != "sum"):
            warn("For better gradients consider 'reduction=sum'")

    def on_epoch_begin(self, **kwargs):
        "Reset the sample and batch counters."
        self.acc_samples, self.acc_batches = 0., 0.

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Accumulate the number of samples and batches seen."
        self.acc_samples += last_input.shape[0]
        self.acc_batches += 1

    def _average_grads(self):
        "Divide the accumulated gradients by the number of accumulated samples."
        for p in self.learn.model.parameters():
            # A trainable parameter may have no gradient at all; skip it.
            if p.requires_grad and p.grad is not None: p.grad.div_(self.acc_samples)

    def on_backward_end(self, **kwargs):
        "Every `n_step` batches average the grads and let the step happen; otherwise ask to skip step and zero-grad."
        if (self.acc_batches % self.n_step) == 0:
            self._average_grads()
            self.acc_samples = 0
        else: return {'skip_step':True, 'skip_zero':True}

    def on_epoch_end(self, **kwargs):
        "Step the rest of the accumulated grads if the epoch was not perfectly divisible by `n_step`."
        # With no pending samples there is nothing left to step, and dividing by
        # `acc_samples == 0` would fill the gradients with NaN right before `opt.step()`.
        if self.acc_samples:
            self._average_grads()
            if not self.drop_last: self.learn.opt.step()
        self.learn.opt.zero_grad()
class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"
    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
        self.data = learn.data
        self.preds, self.y_true, self.losses = preds, y_true, losses
        self.ds_type = ds_type
        self.learn = learn
        # Pick the dataset matching `ds_type`; anything unrecognised falls back to `fix_ds`.
        if ds_type == DatasetType.Train: self.ds = self.data.train_ds
        elif ds_type == DatasetType.Test: self.ds = self.data.test_ds
        elif ds_type == DatasetType.Valid: self.ds = self.data.valid_ds
        elif ds_type == DatasetType.Single: self.ds = self.data.single_ds
        else: self.ds = self.data.fix_ds

    @classmethod
    def from_learner(cls, learn: Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):
        "Gets preds, y_true, losses to construct base class from a learner"
        results = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)
        return cls(learn, *results)

    def top_losses(self, k:int=None, largest=True):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
        n_losses = ifnone(k, len(self.losses))
        return self.losses.topk(n_losses, largest=largest)
# def top_scores(self, metric:Callable=None, k:int=None, largest=True):
# "`k` largest(/smallest) metric scores and indexes, defaulting to all scores (sorted by `largest`)."
# self.scores = metric(self.preds, self.y_true)
# return self.scores.topk(ifnone(k, len(self.scores)), largest=largest)
class ClassificationInterpretation(Interpretation):
    "Interpretation methods for classification models."
    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
        super(ClassificationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)
        # Predicted class: index of the largest value along dim 1 of `preds`.
        self.pred_class = self.preds.argmax(dim=1)

    def confusion_matrix(self, slice_size:int=1):
        "Confusion matrix as an `np.ndarray`."
        x = torch.arange(0,self.data.c)
        def _counts(pred, true):
            # Entry [actual, predicted] counts the items with that pair of classes.
            return ((pred==x[:,None]) & (true==x[:,None,None])).sum(2)
        if slice_size is None:
            cm = _counts(self.pred_class, self.y_true)
        else:
            # Accumulate slice by slice to bound the size of the intermediate tensors.
            cm = torch.zeros(self.data.c, self.data.c, dtype=x.dtype)
            n_items = self.y_true.shape[0]
            for start in range(0, n_items, slice_size):
                stop = start + slice_size
                cm.add_(_counts(self.pred_class[start:stop], self.y_true[start:stop]))
        return to_np(cm)

    def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", slice_size:int=1,
                              norm_dec:int=2, plot_txt:bool=True, return_fig:bool=None, **kwargs)->Optional[plt.Figure]:
        "Plot the confusion matrix, with `title` and using `cmap`."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix(slice_size=slice_size)
        if normalize:
            row_totals = cm.sum(axis=1)[:, np.newaxis]
            cm = cm.astype('float') / row_totals
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = np.arange(self.data.c)
        plt.xticks(tick_marks, self.data.y.classes, rotation=90)
        plt.yticks(tick_marks, self.data.y.classes, rotation=0)

        if plot_txt:
            # Cells above half of the maximum get white text, the others black.
            thresh = cm.max() / 2.
            for i in range(cm.shape[0]):
                for j in range(cm.shape[1]):
                    cell = cm[i, j]
                    coeff = f'{cell:.{norm_dec}f}' if normalize else f'{cell}'
                    text_color = "white" if cell > thresh else "black"
                    plt.text(j, i, coeff, horizontalalignment="center", verticalalignment="center", color=text_color)

        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')
        plt.grid(False)
        if ifnone(return_fig, defaults.return_fig): return fig

    def most_confused(self, min_val:int=1, slice_size:int=1)->Collection[Tuple[str,str,int]]:
        "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
        cm = self.confusion_matrix(slice_size=slice_size)
        # Correct predictions are not confusions: clear the diagonal.
        np.fill_diagonal(cm, 0)
        actual_idxs, predicted_idxs = np.where(cm>=min_val)
        confusions = []
        for i,j in zip(actual_idxs, predicted_idxs):
            confusions.append((self.data.classes[i],self.data.classes[j],cm[i,j]))
        confusions.sort(key=itemgetter(2), reverse=True)
        return confusions
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid):
    "Create a `ClassificationInterpretation` object from `learn` on `ds_type`."
    interpretation = ClassificationInterpretation.from_learner(learn, ds_type=ds_type)
    return interpretation
Learner.interpret = _learner_interpret
class MultiLabelClassificationInterpretation(Interpretation):
    "Interpretation methods for multi-label classification models (not implemented yet: the constructor raises)."
    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid,
                 sigmoid:bool=True, thresh:float=0.3):
        # The class is a placeholder: construction always fails here, so the two
        # statements below are currently unreachable.
        raise NotImplementedError
        super(MultiLabelClassificationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)
        # Thresholds the predictions at `thresh`, after a sigmoid when `sigmoid` is set.
        # NOTE(review): `sigmoid` is called with `dim=1` - confirm this is a valid call before enabling this code.
        self.pred_class = self.preds.sigmoid(dim=1)>thresh if sigmoid else self.preds>thresh
================================================
FILE: fastai/utils/__init__.py
================================================
from .collect_env import *
# The public names of this package are exactly those exported by `collect_env`.
__all__ = [*collect_env.__all__]
================================================
FILE: fastai/utils/check_perf.py
================================================
from ..script import *
from .collect_env import *
# Temporary POC for module-based script
# `check_perf` is handed to `call_parse` (from `..script`) when this module is loaded.
call_parse(check_perf)
================================================
FILE: fastai/utils/collect_env.py
================================================
"Utility functions to help deal with user environment"
from ..imports.torch import *
from ..core import *
from ..script import *
from .pynvml_gate import *
import fastprogress, subprocess, platform
__all__ = ['show_install', 'check_perf']
def get_env(name):
    "Return env var value if it's defined and not an empty string, or return Unknown"
    value = os.environ.get(name, '')
    return value or "Unknown"
def show_install(show_nvidia_smi:bool=False):
    """Print user's setup information.

    Collects software versions, GPU/driver details and environment information into a
    report, prints it inside a ```text fence, and lists optional packages that would
    improve the diagnostics. When `show_nvidia_smi` is set and `nvidia-smi` ran
    successfully, its full output is printed as well.
    """
    import platform, fastai.version

    rep = []
    opt_mods = []

    rep.append(["=== Software ===", None])
    rep.append(["python", platform.python_version()])
    rep.append(["fastai", fastai.__version__])
    rep.append(["fastprogress", fastprogress.__version__])
    rep.append(["torch", torch.__version__])

    # nvidia-smi
    cmd = "nvidia-smi"
    have_nvidia_smi = False
    try: result = subprocess.run(cmd.split(), shell=False, check=False, stdout=subprocess.PIPE)
    # best-effort probe: a failure to launch just means nvidia-smi is unavailable,
    # but KeyboardInterrupt/SystemExit are not swallowed
    except Exception: pass
    else:
        if result.returncode == 0 and result.stdout: have_nvidia_smi = True

    # XXX: if nvidia-smi is not available, another check could be:
    # /proc/driver/nvidia/version on most systems, since it's the
    # currently active version

    if have_nvidia_smi:
        smi = result.stdout.decode('utf-8')
        # matching: "Driver Version: 396.44"
        match = re.findall(r'Driver Version: +(\d+\.\d+)', smi)
        if match: rep.append(["nvidia driver", match[0]])

    available = "available" if torch.cuda.is_available() else "**Not available** "
    rep.append(["torch cuda", f"{torch.version.cuda} / is {available}"])

    # no point reporting on cudnn if cuda is not available, as it
    # seems to be enabled at times even on cpu-only setups
    if torch.cuda.is_available():
        enabled = "enabled" if torch.backends.cudnn.enabled else "**Not enabled** "
        rep.append(["torch cudnn", f"{torch.backends.cudnn.version()} / is {enabled}"])

    rep.append(["\n=== Hardware ===", None])

    # it's possible that torch might not see what nvidia-smi sees?
    gpu_total_mem = []
    nvidia_gpu_cnt = 0
    if have_nvidia_smi:
        try:
            cmd = "nvidia-smi --query-gpu=memory.total --format=csv,nounits,noheader"
            result = subprocess.run(cmd.split(), shell=False, check=False, stdout=subprocess.PIPE)
        except Exception:
            print("have nvidia-smi, but failed to query it")
        else:
            if result.returncode == 0 and result.stdout:
                output = result.stdout.decode('utf-8')
                gpu_total_mem = [int(x) for x in output.strip().split('\n')]
                nvidia_gpu_cnt = len(gpu_total_mem)

    if nvidia_gpu_cnt: rep.append(["nvidia gpus", nvidia_gpu_cnt])

    torch_gpu_cnt = torch.cuda.device_count()
    if torch_gpu_cnt:
        rep.append(["torch devices", torch_gpu_cnt])
        # information for each gpu
        for i in range(torch_gpu_cnt):
            # torch may report more devices than nvidia-smi listed: only show the memory we have
            mem_info = f"{gpu_total_mem[i]}MB | " if i < len(gpu_total_mem) else ""
            rep.append([f" - gpu{i}", mem_info + torch.cuda.get_device_name(i)])
    else:
        if nvidia_gpu_cnt:
            rep.append([f"Have {nvidia_gpu_cnt} GPU(s), but torch can't use them (check nvidia driver)", None])
        else:
            rep.append([f"No GPUs available", None])

    rep.append(["\n=== Environment ===", None])
    rep.append(["platform", platform.platform()])

    if platform.system() == 'Linux':
        distro = try_import('distro')
        if distro:
            # full distro info
            rep.append(["distro", ' '.join(distro.linux_distribution())])
        else:
            opt_mods.append('distro')
            # partial distro info
            rep.append(["distro", platform.uname().version])

    rep.append(["conda env", get_env('CONDA_DEFAULT_ENV')])
    rep.append(["python", sys.executable])
    rep.append(["sys.path", "\n".join(sys.path)])

    print("\n\n```text")

    # section headers (value None) do not take part in the alignment of the keys
    keylen = max([len(e[0]) for e in rep if e[1] is not None])
    for e in rep:
        print(f"{e[0]:{keylen}}", (f": {e[1]}" if e[1] is not None else ""))

    if have_nvidia_smi:
        if show_nvidia_smi: print(f"\n{smi}")
    else:
        if torch_gpu_cnt: print("no nvidia-smi is found")
        else: print("no supported gpus found on this system")

    print("```\n")

    print("Please make sure to include opening/closing ``` when you paste into forums/github to make the reports appear formatted as code sections.\n")

    if opt_mods:
        print("Optional package(s) to enhance the diagnostics can be installed with:")
        print(f"pip install {' '.join(opt_mods)}")
        print("Once installed, re-run this utility to get the additional information")
def pypi_module_version_is_available(module, version):
    "Check whether module==version is available on pypi"
    # returns True/False (or None if failed to execute the check)
    # using a hack that when passing "module==" w/ no version number to pip
    # it "fails" and returns all the available versions in stderr
    try:
        cmd = f"pip install {module}=="
        result = subprocess.run(cmd.split(), shell=False, check=False,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        print(f"Error: {e}")
        return None
    # pip is expected to exit with status 1 and write to stderr
    if result.returncode == 1 and result.stderr:
        listing = result.stderr.decode('utf-8')
        return version in listing
    print(f"Some error in {cmd}")
    return None
def check_perf():
    """Suggest how to improve the setup to speed things up.

    Prints the status of three checks: libjpeg-turbo support in Pillow, whether
    Pillow-SIMD is in use, and whether the CUDA version torch was built against is the
    newest one supported by the installed NVIDIA driver.
    """
    import PIL
    from PIL import features, Image
    from packaging import version

    # `Image.PILLOW_VERSION` was deprecated and then removed from Pillow; `PIL.__version__`
    # is the supported spelling, with the old constant kept as a fallback.
    pillow_version = getattr(PIL, '__version__', None) or Image.PILLOW_VERSION

    print("Running performance checks.")

    # libjpeg_turbo check
    print("\n*** libjpeg-turbo status")
    if version.parse(pillow_version) >= version.parse("5.3.9"):
        if features.check_feature('libjpeg_turbo'):
            print("✔ libjpeg-turbo is on")
        else:
            print("✘ libjpeg-turbo is not on. It's recommended you install libjpeg-turbo to speed up JPEG decoding. See https://docs.fast.ai/performance.html#libjpeg-turbo")
    else:
        print(f"❓ libjpeg-turbo's status can't be derived - need Pillow(-SIMD)? >= 5.4.0 to tell, current version {pillow_version}")
        # XXX: remove this check/note once Pillow and Pillow-SIMD 5.4.0 is available
        pillow_ver_5_4_is_avail = pypi_module_version_is_available("Pillow", "5.4.0")
        if pillow_ver_5_4_is_avail == False:
            print("5.4.0 is not yet available, other than the dev version on github, which can be installed via pip from git+https://github.com/python-pillow/Pillow. See https://docs.fast.ai/performance.html#libjpeg-turbo")

    # Pillow-SIMD check
    print("\n*** Pillow-SIMD status")
    if re.search(r'\.post\d+', pillow_version):
        print(f"✔ Running Pillow-SIMD {pillow_version}")
    else:
        print(f"✘ Running Pillow {pillow_version}; It's recommended you install Pillow-SIMD to speed up image resizing and other operations. See https://docs.fast.ai/performance.html#pillow-simd")

    # CUDA version check
    # compatibility table: k: min nvidia ver is required for v: cuda ver
    # note: windows nvidia driver version is slightly higher, see:
    # https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
    # note: add new entries if pytorch starts supporting new cudaXX
    nvidia2cuda = {
        "410.00": "10.0",
        "384.81": "9.0",
        "367.48": "8.0",
    }
    print("\n*** CUDA status")
    if torch.cuda.is_available():
        pynvml = load_pynvml_env()
        nvidia_ver = (pynvml.nvmlSystemGetDriverVersion().decode('utf-8') if platform.system() != "Darwin" else "Cannot be determined on OSX yet")
        cuda_ver = torch.version.cuda
        max_cuda = "8.0"
        # walk the table from the oldest to the newest driver; the keys are *minimum* driver
        # versions, so a driver equal to the minimum qualifies as well
        for k in sorted(nvidia2cuda, key=version.parse):
            if version.parse(nvidia_ver) >= version.parse(k): max_cuda = nvidia2cuda[k]
        if version.parse(str(max_cuda)) <= version.parse(cuda_ver):
            print(f"✔ Running the latest CUDA {cuda_ver} with NVIDIA driver {nvidia_ver}")
        else:
            # report the CUDA version the driver actually supports rather than a fixed "cuda10"
            print(f"✘ You are running pytorch built against cuda {cuda_ver}, your NVIDIA driver {nvidia_ver} supports cuda{max_cuda}. See https://pytorch.org/get-started/locally/ to install pytorch built against the faster CUDA version.")
    else:
        print(f"❓ Running cpu-only torch version, CUDA check is not relevant")

    print("\nRefer to https://docs.fast.ai/performance.html to make sense out of these checks and suggestions.")
================================================
FILE: fastai/utils/ipython.py
================================================
"ipython utils"
import os, functools, traceback, gc
def is_in_ipython():
    "Is the code running in the ipython environment (jupyter including)"
    program_name = os.path.basename(os.getenv('_', ''))
    # jupyter-notebook and ipython are recognised from the name of the launching program
    launcher_markers = ('jupyter-notebook', 'ipython')
    if any(marker in program_name for marker in launcher_markers):
        return True
    # ipython-notebook
    return 'JPY_PARENT_PID' in os.environ
IS_IN_IPYTHON = is_in_ipython()
def is_in_colab():
    "Is the code running in Google Colaboratory? (ipython environment in which `google.colab` can be imported)"
    if not IS_IN_IPYTHON: return False
    try:
        from google import colab
        return True
    # any failure to import means "not colab"; KeyboardInterrupt/SystemExit still propagate
    except Exception: return False
IS_IN_COLAB = is_in_colab()
def get_ref_free_exc_info():
    """Return `sys.exc_info()` after clearing the local variables of each traceback frame.

    Freeing the traceback from references to `locals()` in each frame avoids the circular
    references that leave `gc.collect()` unable to reclaim memory.
    Returns the `(type, value, traceback)` triple of the exception being handled.
    """
    # imported here because this module's import line does not bring `sys` into scope
    import sys
    exc_type, exc_val, exc_tb = sys.exc_info()
    traceback.clear_frames(exc_tb)
    return (exc_type, exc_val, exc_tb)
def gpu_mem_restore(func):
    """Reclaim GPU RAM if CUDA out of memory happened, or execution was interrupted.

    Decorator. Outside of ipython, or when the `FASTAI_TB_CLEAR_FRAMES` env var is "0",
    `func` is called directly. Otherwise, for "CUDA out of memory" and "device-side assert
    triggered" errors (or any exception when the env var is "1"), the traceback frames are
    cleared and garbage is collected before an exception of the same type is raised again.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tb_clear_frames = os.environ.get('FASTAI_TB_CLEAR_FRAMES', None)
        if not IS_IN_IPYTHON or tb_clear_frames=="0":
            return func(*args, **kwargs)

        try:
            return func(*args, **kwargs)
        except Exception as e:
            if ("CUDA out of memory" in str(e) or
                "device-side assert triggered" in str(e) or
                tb_clear_frames == "1"):
                exc_type, exc_val, exc_tb = get_ref_free_exc_info() # must!
                gc.collect()
                if "device-side assert triggered" in str(e):
                    # imported here because `warn` is not otherwise in scope in this module
                    from warnings import warn
                    warn("""When 'device-side assert triggered' error happens, it's not possible to recover and you must restart the kernel to continue. Use os.environ['CUDA_LAUNCH_BLOCKING']="1" before restarting to debug""")
                raise exc_type(exc_val).with_traceback(exc_tb) from None
            else: raise # re-raises the exact last exception
    return wrapper
class gpu_mem_restore_ctx():
    "context manager to reclaim RAM if an exception happened under ipython"
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_val:
            # Drop the frames' locals, collect garbage, then raise the same kind of exception.
            traceback.clear_frames(exc_tb)
            gc.collect()
            raise exc_type(exc_val).with_traceback(exc_tb) from None
        return True
================================================
FILE: fastai/utils/mem.py
================================================
"Utility functions for memory management"
from ..imports.torch import *
from ..core import *
from ..script import *
import functools, threading, time
from .pynvml_gate import *
from collections import namedtuple
#is_osx = platform.system() == "Darwin"
# Whether torch reports a usable CUDA device; evaluated once, at import time.
use_gpu = torch.cuda.is_available()
# Memory figures for one gpu; the functions of this module fill it with values in MBs.
GPUMemory = namedtuple('GPUMemory', ['total', 'free', 'used'])
# The nvml wrapper is only loaded when there is a gpu to query.
if use_gpu:
pynvml = load_pynvml_env()
def preload_pytorch():
    "Move a 1x1 tensor of ones to the gpu (presumably to pay the CUDA start-up cost ahead of time)"
    dummy = torch.ones((1, 1))
    dummy.cuda()
def b2mb(num):
    """ convert a size in bytes to MBs (2**20 bytes each), dropping the fractional part """
    bytes_per_mb = 1 << 20
    return int(num / bytes_per_mb)
def gpu_mem_get(id=None):
    "get total, free and used memory (in MBs) for gpu `id`. if `id` is not passed, currently selected torch device is used"
    if not use_gpu: return GPUMemory(0, 0, 0)
    if id is None: id = torch.cuda.current_device()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(id)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return GPUMemory(*(map(b2mb, [info.total, info.free, info.used])))
    except Exception:
        # Best effort: report zeros when the device cannot be queried. `Exception`
        # rather than a bare `except`, so KeyboardInterrupt/SystemExit are not swallowed.
        return GPUMemory(0, 0, 0)
def gpu_mem_get_all():
    "get total, free and used memory (in MBs) for each gpu counted by `pynvml.nvmlDeviceGetCount`"
    if not use_gpu: return []
    n_gpus = pynvml.nvmlDeviceGetCount()
    return [gpu_mem_get(i) for i in range(n_gpus)]
def gpu_mem_get_free():
    "get free memory (in MBs) for the currently selected gpu id, w/o emptying the cache"
    mem = gpu_mem_get()
    return mem.free
def gpu_mem_get_free_no_cache():
    "get free memory (in MBs) for the currently selected gpu id, after emptying the cache"
    # release torch's cached blocks first so they are counted as free
    torch.cuda.empty_cache()
    mem = gpu_mem_get()
    return mem.free
def gpu_mem_get_used():
    "get used memory (in MBs) for the currently selected gpu id, w/o emptying the cache"
    mem = gpu_mem_get()
    return mem.used
def gpu_mem_get_used_fast(gpu_handle):
    "get used memory (in MBs) for the gpu behind the nvml `gpu_handle`, w/o emptying the cache and w/o looking the handle up again"
    return b2mb(pynvml.nvmlDeviceGetMemoryInfo(gpu_handle).used)
def gpu_mem_get_used_no_cache():
    "get used memory (in MBs) for the currently selected gpu id, after emptying the cache"
    # release torch's cached blocks first so they are not counted as used
    torch.cuda.empty_cache()
    mem = gpu_mem_get()
    return mem.used
def gpu_with_max_free_mem():
    "get [gpu_id, its_free_ram] for the first gpu with highest available RAM, or [None, 0] when no gpu is reported"
    gpus = gpu_mem_get_all()
    if len(gpus) == 0: return None, 0
    free_mbs = np.array([gpu.free for gpu in gpus])
    # argmax returns the first index on ties
    best = np.argmax(free_mbs)
    return best, free_mbs[best]
class GPUMemTrace():
    "Trace allocated and peaked GPU memory usage (deltas)."
    def __init__(self, silent=False, ctx=None, on_exit_report=True):
        assert torch.cuda.is_available(), "pytorch CUDA is required"
        self.silent = silent # shortcut to turn off all reports from constructor
        self.ctx = ctx # default context note in report
        self.on_exit_report = on_exit_report # auto-report on ctx manager exit (default: True)
        self.start()

    def reset(self):
        "Take the currently used memory as the new baseline for both deltas."
        self.used_start = gpu_mem_get_used_no_cache()
        self.used_peak = self.used_start

    def data_set(self):
        "Refresh `delta_used` and `delta_peaked` from the current measurements."
        # delta_used is the difference between current used mem and used mem at the start
        self.delta_used = gpu_mem_get_used_no_cache() - self.used_start

        # delta_peaked is the overhead if any. It is calculated as follows:
        #
        # 1. The difference between the peak memory and the used memory at the
        #    start is measured:
        # 2a. If it's negative, then delta_peaked is 0
        # 2b. Otherwise, if used_delta is positive it gets subtracted from delta_peaked
        # XXX: 2a shouldn't be needed once we have a reliable peak counter
        self.delta_peaked = self.used_peak - self.used_start
        if self.delta_peaked < 0: self.delta_peaked = 0
        elif self.delta_used > 0: self.delta_peaked -= self.delta_used

    def data(self):
        "Return `(delta_used, delta_peaked)`, refreshed first while the tracer is running."
        if self.is_running: self.data_set()
        return self.delta_used, self.delta_peaked

    def start(self):
        "Mark the tracer as running, reset the baseline and start sampling the peak."
        self.is_running = True
        self.reset()
        self.peak_monitor_start()

    def stop(self):
        "Stop sampling the peak and freeze the deltas."
        self.peak_monitor_stop()
        self.data_set()
        self.is_running = False

    def __enter__(self):
        # NOTE(review): `__init__` has already called `start()`, so this starts a second
        # sampling thread while the first one is still running; both exit on `stop()`.
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()
        if self.on_exit_report: self.report('exit')

    def __del__(self):
        # Only stop a tracer that is actually running: when `__init__` failed before
        # `start()` (e.g. on the CUDA assert) the attributes `stop()` needs do not exist,
        # and a tracer that was already stopped has nothing left to stop.
        if getattr(self, 'is_running', False): self.stop()

    def __repr__(self):
        delta_used, delta_peaked = self.data()
        return f"△Used Peaked MB: {delta_used:6,.0f} {delta_peaked:6,.0f}"

    def _get_ctx(self, subctx=None):
        "Return ' (ctx: subctx)' or ' (ctx)' or ' (subctx)' or '' depending on this and constructor arguments"
        l = []
        if self.ctx is not None: l.append(self.ctx)
        if subctx is not None: l.append(subctx)
        return '' if len(l) == 0 else f" ({': '.join(l)})"

    # NOTE(review): `__init__` stores the flag in an instance attribute of the same name,
    # which shadows this method: `tracer.silent(True)` fails because `tracer.silent` is a
    # bool. Set the attribute directly (`tracer.silent = True`) instead.
    def silent(self, silent=True):
        self.silent = silent

    def report(self, subctx=None):
        "Print delta used+peaked, and an optional context note, which can also be preset in constructor"
        if self.silent: return
        print(f"{ self.__repr__() }{ self._get_ctx(subctx) }")

    def report_n_reset(self, subctx=None):
        "Print delta used+peaked, and an optional context note. Then reset counters"
        self.report(subctx)
        self.reset()

    def peak_monitor_start(self):
        "Start a daemon thread that samples the used memory into `used_peak`."
        self.peak_monitoring = True

        # continually sample GPU RAM usage
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()

    def peak_monitor_stop(self):
        "Ask the sampling thread to exit; it is not joined."
        self.peak_monitoring = False

    # XXX: this is an unreliable function, since there is no thread priority
    # control and it may not run enough or not run at all
    def peak_monitor_func(self):
        "Thread body: keep the maximum of the sampled used memory until monitoring is switched off."
        gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
        while True:
            self.used_peak = max(gpu_mem_get_used_fast(gpu_handle), self.used_peak)
            if not self.peak_monitoring: break
            time.sleep(0.001) # 1msec
def gpu_mem_trace(func):
    "A decorator that runs `func` inside a `GPUMemTrace` labelled with its qualified name, reporting on exit"
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        tracer = GPUMemTrace(ctx=func.__qualname__, on_exit_report=True)
        with tracer: return func(*args, **kwargs)
    return wrapper
def _smallest_int_dtype(lo, hi, candidates):
    "Return the first dtype of `candidates` whose range holds both `lo` and `hi`, or None if there is none."
    for dtype in candidates:
        info = np.iinfo(dtype)
        if info.min <= lo and hi <= info.max: return dtype
    return None

def reduce_mem_usage(df):
    """ iterate through all the columns of a dataframe and modify the data type
    to reduce memory usage.

    `df` is modified in place and also returned. Category, `datetime64[ns]` and bool
    columns are left alone, object columns become categories, integer columns get the
    narrowest integer type of the same signedness that holds their min and max, and
    wider float columns become float32 when their range allows it. The memory usage
    before and after is printed.
    """
    start_mem = df.memory_usage().sum() / 1024**2
    print('Memory usage of dataframe is {:.2f} MB'.format(start_mem))

    for col in df.columns:
        col_type = df[col].dtype
        if str(col_type) == 'category' or col_type == 'datetime64[ns]' or col_type == bool: continue
        if col_type == object:
            df[col] = df[col].astype('category')
            continue

        c_min, c_max = df[col].min(), df[col].max()
        type_name = str(col_type)
        if type_name.startswith('int'):
            new_type = _smallest_int_dtype(c_min, c_max, (np.int8, np.int16, np.int32, np.int64))
            if new_type is not None: df[col] = df[col].astype(new_type)
        elif type_name.startswith('uint'):
            # Unsigned columns stay unsigned: sending them through the float branch
            # would round every value that float32 cannot represent exactly.
            new_type = _smallest_int_dtype(c_min, c_max, (np.uint8, np.uint16, np.uint32, np.uint64))
            if new_type is not None: df[col] = df[col].astype(new_type)
        else:
            # Already 32 bits or less (e.g. float16): casting to float32 could only grow it.
            if getattr(col_type, 'itemsize', 8) <= 4: continue
            # float16 is deliberately not attempted
            if c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                df[col] = df[col].astype(np.float32)
            else:
                # f-string so that non-string column names do not break the message
                print(f'Error {col} Value would be a float64. Disregarding.')

    end_mem = df.memory_usage().sum() / 1024**2
    print('Memory usage after optimization is: {:.2f} MB'.format(end_mem))
    print('Decreased by {:.1f}%'.format(100 * (start_mem - end_mem) / start_mem))

    return df
================================================
FILE: fastai/utils/mod_display.py
================================================
" Utils for modifying what is displayed in notebooks and command line"
import fastai
import fastprogress
from ..basic_train import *
from ..core import *
__all__ = ['progress_disabled_ctx']
class progress_disabled_ctx():
    "Context manager to disable the progress update bar and Recorder print."
    def __init__(self,learn:Learner):
        self.learn = learn

    def __enter__(self):
        #silence progress bar, remembering the previous flag so that `__exit__` can restore it
        self.orig_no_bar = getattr(fastprogress.fastprogress, 'NO_BAR', False)
        fastprogress.fastprogress.NO_BAR = True
        fastai.basic_train.master_bar,fastai.basic_train.progress_bar = fastprogress.force_console_behavior()
        self.orig_callback_fns = copy(self.learn.callback_fns)
        # the Recorder is registered as a partial; find it through its `.func`
        rec_name = [x for x in self.learn.callback_fns if hasattr(x, 'func') and x.func == Recorder]
        if len(rec_name):
            rec_idx = self.learn.callback_fns.index(rec_name[0])
            self.learn.callback_fns[rec_idx] = partial(Recorder, add_time=True, silent=True) #silence recorder
        return self.learn

    def __exit__(self, *args):
        # Undo everything `__enter__` changed, including the global NO_BAR flag, which
        # would otherwise stay True after the block.
        fastprogress.fastprogress.NO_BAR = self.orig_no_bar
        fastai.basic_train.master_bar,fastai.basic_train.progress_bar = master_bar,progress_bar
        self.learn.callback_fns = self.orig_callback_fns
================================================
FILE: fastai/utils/pynvml_gate.py
================================================
"""Get OS specific nvml wrapper. On OSX we use pynvx as drop in replacement for pynvml"""
import platform
from ..script import *
#
# BEGIN: Temporary workaround for nvml.dll load issue in Win10
#
# Remove once nicolargo/nvidia-ml-py3#2 and a new version of the module is released
# (OR fbcotter/py3nvml#10 but will require extra work to rename things)
# Refer https://forums.fast.ai/t/nvml-dll-loading-issue-in-nvidia-ml-py3-7-352-0-py-0/39684/8
import threading
from ctypes import *
# Handle on nvml.dll once loaded; stays None when no dll is found or loaded.
nvmlLib = None
# Serialises the loading between threads.
libLoadLock = threading.Lock()

def _LoadNvmlLibrary():
    '''
    Load nvml.dll into `nvmlLib` unless it is loaded already (the dll is only
    searched for on Windows; elsewhere `nvmlLib` stays None)
    '''
    global nvmlLib
    if nvmlLib is not None: return
    with libLoadLock:
        # check again now that we hold the lock: another thread may have loaded it
        if nvmlLib is not None: return
        try:
            if not sys.platform.startswith("win"):
                nvmlLib = None
                return
            candidates = [
                os.path.join(os.getenv("ProgramFiles", r"C:\Program Files"), r"NVIDIA Corporation\NVSMI\nvml.dll"),
                os.path.join(os.getenv("WinDir", r"C:\Windows"), r"System32\nvml.dll"),
            ]
            found = next((p for p in candidates if os.path.isfile(p)), None)
            nvmlLib = CDLL(found) if found is not None else None
        except OSError:
            nvmlLib = None
#
# END: Temporary workaround for nvml.dll load issue in Win10
#
def load_pynvml_env():
    "Import the nvml wrapper (`pynvx`'s replacement on OSX, `pynvml` otherwise), call its `nvmlInit()` and return the module"
    import pynvml # nvidia-ml-py3

    #
    # BEGIN: Temporary workaround for nvml.dll load issue in Win10 (continued)
    _LoadNvmlLibrary()
    pynvml.nvmlLib = nvmlLib
    #
    # END: Temporary workaround for nvml.dll load issue in Win10
    #

    if platform.system() == "Darwin":
        try:
            from pynvx import pynvml
        except Exception:
            # `Exception` rather than a bare `except`, so that Ctrl-C during the import
            # is not reported as a missing package.
            print("please install pynvx on OSX: pip install pynvx")
            sys.exit(1)

    pynvml.nvmlInit()
    return pynvml
================================================
FILE: fastai/utils/show_install.py
================================================
from ..script import *
from .collect_env import *
# Temporary POC for module-based script
# Entry point of the module-based script: `call_parse` (from `..script`) wraps `main`,
# presumably building the command line from the `Param` annotation -- TODO confirm.
@call_parse
def main(show_nvidia_smi:Param(opt=False, nargs='?', type=bool)=False):
# Forward the flag to `show_install` and hand back whatever it returns.
return show_install(show_nvidia_smi)
================================================
FILE: fastai/version.py
================================================
# `__version__` is the only name exported by a star-import of this module.
__all__ = ['__version__']
# Version string of the package.
__version__ = '1.0.56.dev0'
================================================
FILE: fastai/vision/__init__.py
================================================
from .. import basics
from ..basics import *
from .learner import *
from .image import *
from .data import *
from .transform import *
from .tta import *
from . import models
from .. import vision
__all__ = [*basics.__all__, *learner.__all__, *data.__all__, *image.__all__, *transform.__all__, *tta.__all__, 'models', 'vision']
================================================
FILE: fastai/vision/cyclegan.py
================================================
from ..torch_core import *
from ..layers import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
__all__ = ['CycleGAN', 'CycleGanLoss', 'AdaptiveLoss', 'CycleGANTrainer']
def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):
    "Return the layers `[ConvTranspose2d, norm_layer(ch_out), ReLU]`; with the default `ks` and `stride` the transposed conv doubles the spatial size."
    upconv = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias)
    norm = norm_layer(ch_out)
    return [upconv, norm, nn.ReLU(True)]
def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,
                       pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:
    "Return the layers `[padding, Conv2d, norm_layer(ch_out), ReLU]`: explicit padding layer only for 'reflection'/'border', ReLU only if `activ`."
    pad_layers = {'reflection': nn.ReflectionPad2d, 'border': nn.ReplicationPad2d}
    layers = [pad_layers[pad_mode](pad)] if pad_mode in pad_layers else []
    # only 'zeros' lets the conv pad by itself
    conv_pad = pad if pad_mode == 'zeros' else 0
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=conv_pad, stride=stride, bias=bias)
    if init:
        init(conv.weight)
        conv_bias = getattr(conv, 'bias', None)
        if hasattr(conv_bias, 'data'): conv_bias.data.fill_(0.)
    layers.extend([conv, norm_layer(ch_out)])
    if activ: layers.append(nn.ReLU(inplace=True))
    return layers
class ResnetBlock(Module):
    "Residual block: the input plus two `pad_conv_norm_relu` stacks (no ReLU on the second, optional dropout in between)."
    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):
        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'
        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
        first = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)
        if dropout != 0: first.append(nn.Dropout(dropout))
        second = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)
        self.conv_block = nn.Sequential(*first, *second)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,
                     dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:
    "Generator: 7x7 conv, two stride-2 convs, `n_blocks` `ResnetBlock`s, two transposed convs, then a 7x7 conv to `ch_out` followed by tanh."
    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
    bias = (norm_layer == nn.InstanceNorm2d)
    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)
    # downsampling: the number of features doubles at each step
    for _ in range(2):
        layers.extend(pad_conv_norm_relu(n_ftrs, n_ftrs *2, 'zeros', norm_layer, stride=2, bias=bias))
        n_ftrs *= 2
    for _ in range(n_blocks):
        layers.append(ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias))
    # upsampling: the number of features halves at each step
    for _ in range(2):
        layers.extend(convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias))
        n_ftrs //= 2
    layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()])
    return nn.Sequential(*layers)
def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,
                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:
    "Return the layers `[Conv2d, norm_layer(ch_out), LeakyReLU(slope)]`: norm only if `norm_layer` is given, activation only if `activ`."
    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)
    if init:
        init(conv.weight)
        conv_bias = getattr(conv, 'bias', None)
        if hasattr(conv_bias, 'data'): conv_bias.data.fill_(0.)
    norm = [] if norm_layer is None else [norm_layer(ch_out)]
    act = [nn.LeakyReLU(slope, inplace=True)] if activ else []
    return [conv, *norm, *act]
def critic(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:
    "Critic: `n_layers` stride-2 4x4 convs, one stride-1 4x4 conv, then a 4x4 conv to a single channel (and a sigmoid if `sigmoid`)."
    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)
    bias = (norm_layer == nn.InstanceNorm2d)
    # the first conv gets no norm layer
    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)
    for i in range(n_layers-1):
        out_ftrs = 2*n_ftrs if i <= 3 else n_ftrs
        layers.extend(conv_norm_lr(n_ftrs, out_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias))
        n_ftrs = out_ftrs
    last_ftrs = 2*n_ftrs if n_layers <=3 else n_ftrs
    layers.extend(conv_norm_lr(n_ftrs, last_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias))
    layers.append(nn.Conv2d(last_ftrs, 1, kernel_size=4, stride=1, padding=1))
    if sigmoid: layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)
class CycleGAN(Module):
    "Two generators (`G_A`, `G_B`) and two critics (`D_A`, `D_B`) for unpaired translation between datasets A and B."
    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,
                 drop:float=0., norm_layer:nn.Module=None):
        use_sigmoid = not lsgan
        self.D_A = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)
        self.D_B = critic(ch_in, n_features, disc_layers, norm_layer, sigmoid=use_sigmoid)
        self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)
        #G_A: takes real input B and generates fake input A
        #G_B: takes real input A and generates fake input B
        #D_A: trained to make the difference between real input A and fake input A
        #D_B: trained to make the difference between real input B and fake input B

    def forward(self, real_A, real_B):
        "Training: `[fake_A, fake_B, idt_A, idt_B]`. Eval: the two fakes stacked along a new dim 1."
        fake_A, fake_B = self.G_A(real_B), self.G_B(real_A)
        if self.training:
            # each generator applied to the dataset it generates
            idt_A, idt_B = self.G_A(real_A), self.G_B(real_B)
            return [fake_A, fake_B, idt_A, idt_B]
        return torch.cat([fake_A[:,None],fake_B[:,None]], 1)
class AdaptiveLoss(Module):
    "Wrap `crit` so that the target is a bool: `output` is compared to all ones (True) or all zeros (False)."
    def __init__(self, crit): self.crit = crit

    def forward(self, output, target:bool):
        make_targ = output.new_ones if target else output.new_zeros
        return self.crit(output, make_targ(*output.size()))
class CycleGanLoss(Module):
    "Generator loss of a `CycleGAN`: identity + adversarial + cycle-consistency terms."
    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):
        self.cgan,self.l_A,self.l_B,self.l_idt = cgan,lambda_A,lambda_B,lambda_idt
        #self.crit = F.mse_loss if lsgan else F.binary_cross_entropy
        self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)

    def set_input(self, input):
        "Store the real batch `(real_A, real_B)` that `forward` compares against."
        self.real_A,self.real_B = input

    def forward(self, output, target):
        "Return the summed loss for `output = [fake_A, fake_B, idt_A, idt_B]`; the three terms are kept in `self.metrics`. `target` is not used."
        fake_A, fake_B, idt_A, idt_B = output
        #Generators should return identity on the datasets they try to convert to:
        #idt_A = G_A(real_A) has to stay close to real_A, and idt_B = G_B(real_B) close to real_B
        idt_loss = self.l_idt * (self.l_A * F.l1_loss(idt_A, self.real_A) + self.l_B * F.l1_loss(idt_B, self.real_B))
        #Generators are trained to trick the critics so the following should be ones
        gen_loss = self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)
        #Cycle loss
        cycle_loss = self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)
        cycle_loss += self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)
        self.metrics = [idt_loss, gen_loss, cycle_loss]
        return idt_loss + gen_loss + cycle_loss
class CycleGANTrainer(LearnerCallback):
    "`LearnerCallback` that handles cycleGAN Training."
    _order=-20
    def _set_trainable(self, D_A=False, D_B=False):
        "Make only the requested critic trainable, or only the generators when no critic is requested."
        train_gen = not (D_A or D_B)
        model = self.learn.model
        for generator in (model.G_A, model.G_B): requires_grad(generator, train_gen)
        requires_grad(model.D_A, D_A)
        requires_grad(model.D_B, D_B)
        if train_gen: return
        # copy the hyper-parameters of the main optimizer to both critic optimizers
        for critic_opt in (self.opt_D_A, self.opt_D_B):
            critic_opt.lr, critic_opt.mom = self.learn.opt.lr, self.learn.opt.mom
            critic_opt.wd, critic_opt.beta = self.learn.opt.wd, self.learn.opt.beta

    def on_train_begin(self, **kwargs):
        "Create the various optimizers."
        model = self.learn.model
        self.G_A,self.G_B = model.G_A,model.G_B
        self.D_A,self.D_B = model.D_A,model.D_B
        self.crit = self.learn.loss_func.crit
        generators = nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))
        self.opt_G = self.learn.opt.new([generators])
        self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])
        self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])
        # the learner's own optimizer now steps the generators
        self.learn.opt.opt = self.opt_G.opt
        self._set_trainable()
        self.names = ['idt_loss', 'gen_loss', 'cyc_loss', 'da_loss', 'db_loss']
        self.learn.recorder.no_val=True
        self.learn.recorder.add_metric_names(self.names)
        self.smootheners = {name:SmoothenValue(0.98) for name in self.names}

    def on_batch_begin(self, last_input, **kwargs):
        "Register the `last_input` in the loss function."
        self.learn.loss_func.set_input(last_input)

    def on_batch_end(self, last_input, last_output, **kwargs):
        "Zero the generators' gradients, then take one optimizer step on each critic and record the losses."
        self.G_A.zero_grad(); self.G_B.zero_grad()
        # detached, so the critic losses do not backpropagate into the generators
        fake_A, fake_B = last_output[0].detach(), last_output[1].detach()
        real_A, real_B = last_input

        self._set_trainable(D_A=True)
        self.D_A.zero_grad()
        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))
        loss_D_A.backward()
        self.opt_D_A.step()

        self._set_trainable(D_B=True)
        self.D_B.zero_grad()
        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))
        loss_D_B.backward()
        self.opt_D_B.step()

        # back to training the generators for the next batch
        self._set_trainable()
        batch_metrics = self.learn.loss_func.metrics + [loss_D_A, loss_D_B]
        for name,value in zip(self.names,batch_metrics): self.smootheners[name].add_value(value)

    def on_epoch_end(self, last_metrics, **kwargs):
        "Put the various losses in the recorder."
        smoothed = [s.smooth for s in self.smootheners.values()]
        return add_metrics(last_metrics, smoothed)
================================================
FILE: fastai/vision/data.py
================================================
"Manages data input pipeline - folders -> transform -> batch input. Includes support for classification, segmentation and bounding boxes"
from numbers import Integral
from ..torch_core import *
from .image import *
from .transform import *
from ..data_block import *
from ..basic_data import *
from ..layers import *
from .learner import *
from torchvision import transforms as tvt
__all__ = ['get_image_files', 'denormalize', 'get_annotations', 'ImageDataBunch',
'ImageList', 'normalize', 'normalize_funcs', 'resize_to',
'channel_view', 'mnist_stats', 'cifar_stats', 'imagenet_stats', 'imagenet_stats_inception', 'download_images',
'verify_images', 'bb_pad_collate', 'ImageImageList', 'PointsLabelList',
'ObjectCategoryList', 'ObjectItemList', 'SegmentationLabelList', 'SegmentationItemList', 'PointsItemList']
# every file extension that `mimetypes` maps to an `image/...` type
image_extensions = {ext for ext,mime in mimetypes.types_map.items() if mime.startswith('image/')}
def get_image_files(c:PathOrStr, check_ext:bool=True, recurse=False)->FilePathList:
    "Return list of files in `c` that are images. `check_ext` will filter to `image_extensions`."
    extensions = image_extensions if check_ext else None
    return get_files(c, extensions=extensions, recurse=recurse)
def get_annotations(fname, prefix=None):
    "Open a COCO style json in `fname` and returns the lists of filenames (with maybe `prefix`) and labelled bboxes."
    # `with` so that the file is closed as soon as the json is parsed
    with open(fname) as f: annot_dict = json.load(f)
    prefix = '' if prefix is None else prefix
    classes = {o['id']: o['name'] for o in annot_dict['categories']}
    id2bboxes, id2cats = collections.defaultdict(list), collections.defaultdict(list)
    for o in annot_dict['annotations']:
        bb = o['bbox']
        # the 4 values are stored as [bb[1], bb[0], bb[1]+bb[3], bb[0]+bb[2]]
        id2bboxes[o['image_id']].append([bb[1],bb[0], bb[3]+bb[1], bb[2]+bb[0]])
        id2cats[o['image_id']].append(classes[o['category_id']])
    # only the images that have at least one annotation are kept
    id2images = {o['id']: prefix + o['file_name'] for o in annot_dict['images'] if o['id'] in id2bboxes}
    ids = list(id2images)
    return [id2images[k] for k in ids], [[id2bboxes[k], id2cats[k]] for k in ids]
def bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:
    "Function that collect `samples` of labelled bboxes and adds padding with `pad_idx`."
    # plain int targets are handed to the default collate function
    if isinstance(samples[0][1], int): return data_collate(samples)
    n_samples = len(samples)
    max_len = max(len(s[1].data[1]) for s in samples)
    bboxes = torch.zeros(n_samples, max_len, 4)
    labels = torch.zeros(n_samples, max_len).long() + pad_idx
    imgs = []
    for i,s in enumerate(samples):
        imgs.append(s[0].data[None])
        bbs, lbls = s[1].data
        if bbs.nelement() != 0:
            # the last `len(lbls)` slots are filled, so the padding sits in front
            n_lbls = len(lbls)
            bboxes[i,-n_lbls:] = bbs
            labels[i,-n_lbls:] = tensor(lbls)
    return torch.cat(imgs,0), (bboxes,labels)
def normalize(x:TensorImage, mean,std:Tensor)->TensorImage:
    "Normalize `x` with `mean` and `std`."
    # two trailing unit axes so that the stats broadcast over the last two dims of `x`
    mean_b = mean[...,None,None]
    std_b = std[...,None,None]
    return (x-mean_b) / std_b
def denormalize(x:TensorImage, mean,std:Tensor, do_x:bool=True)->TensorImage:
    "Denormalize `x` with `mean` and `std` (`x` is only moved to the cpu when `do_x` is False)."
    x = x.cpu()
    if not do_x: return x
    return x.float()*std[...,None,None] + mean[...,None,None]
def _normalize_batch(b:Tuple[Tensor,Tensor], mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Tensor,Tensor]:
    "`b` = `x`,`y` - normalize `x` array of imgs and `do_y` optionally `y`."
    x,y = b
    # the stats follow the batch to its device
    device = x.device
    mean,std = mean.to(device),std.to(device)
    if do_x: x = normalize(x,mean,std)
    # `y` is only normalized when it has 4 dims
    if do_y and len(y.shape) == 4: y = normalize(y,mean,std)
    return x,y
def normalize_funcs(mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Callable,Callable]:
    "Create normalize/denormalize func using `mean` and `std`, can specify `do_x` and `do_y`."
    mean,std = tensor(mean),tensor(std)
    # keyword arguments common to both partials
    shared = dict(mean=mean, std=std, do_x=do_x)
    norm = partial(_normalize_batch, do_y=do_y, **shared)
    denorm = partial(denormalize, **shared)
    return (norm, denorm)
# Ready-made `(mean, std)` pairs, presumably one value per image channel, in the
# order `normalize_funcs(mean, std)` takes them.
cifar_stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
imagenet_stats_inception = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# the same value repeated three times
mnist_stats = ([0.15]*3, [0.15]*3)
def channel_view(x:Tensor)->Tensor:
    "Make channel the first axis of `x` and flatten remaining axes"
    n_channels = x.shape[1]
    # swap the first two axes, then make the data contiguous so that it can be viewed
    swapped = x.transpose(0,1).contiguous()
    return swapped.view(n_channels,-1)
class ImageDataBunch(DataBunch):
    "DataBunch suitable for computer vision."
    _square_show = True

    @classmethod
    def create_from_ll(cls, lls:LabelLists, bs:int=64, val_bs:int=None, ds_tfms:Optional[TfmList]=None,
                num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None,
                test:Optional[PathOrStr]=None, collate_fn:Callable=data_collate, size:int=None, no_check:bool=False,
                resize_method:ResizeMethod=None, mult:int=None, padding_mode:str='reflection',
                mode:str='bilinear', tfm_y:bool=False)->'ImageDataBunch':
        "Create an `ImageDataBunch` from `LabelLists` `lls` with potential `ds_tfms`."
        tfm_kwargs = dict(tfms=ds_tfms, size=size, resize_method=resize_method, mult=mult,
                          padding_mode=padding_mode, mode=mode, tfm_y=tfm_y)
        lls = lls.transform(**tfm_kwargs)
        if test is not None: lls.add_test_folder(test)
        bunch_kwargs = dict(bs=bs, val_bs=val_bs, dl_tfms=dl_tfms, num_workers=num_workers,
                            collate_fn=collate_fn, device=device, no_check=no_check)
        return lls.databunch(**bunch_kwargs)

    @classmethod
    def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid',
                    valid_pct=None, seed:int=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':
        "Create from imagenet style dataset in `path` with `train`,`valid`,`test` subfolders (or provide `valid_pct`)."
        items = ImageList.from_folder(Path(path))
        # split by folder unless a random validation percentage is requested
        if valid_pct is None: split = items.split_by_folder(train=train, valid=valid)
        else: split = items.split_by_rand_pct(valid_pct, seed)
        labelled = split.label_from_folder(classes=classes)
        return cls.create_from_ll(labelled, **kwargs)

    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr=None, label_delim:str=None, valid_pct:float=0.2,
                seed:int=None, fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='', **kwargs:Any)->'ImageDataBunch':
        "Create from a `DataFrame` `df`."
        items = ImageList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)
        split = items.split_by_rand_pct(valid_pct, seed)
        labelled = split.label_from_df(label_delim=label_delim, cols=label_col)
        return cls.create_from_ll(labelled, **kwargs)

    @classmethod
    def from_csv(cls, path:PathOrStr, folder:PathOrStr=None, label_delim:str=None, csv_labels:PathOrStr='labels.csv',
                 valid_pct:float=0.2, seed:int=None, fn_col:int=0, label_col:int=1, suffix:str='', delimiter:str=None,
                 header:Optional[Union[int,str]]='infer', **kwargs:Any)->'ImageDataBunch':
        "Create from a csv file in `path/csv_labels`."
        path = Path(path)
        labels_df = pd.read_csv(path/csv_labels, header=header, delimiter=delimiter)
        return cls.from_df(path, labels_df, folder=folder, label_delim=label_delim, valid_pct=valid_pct, seed=seed,
                           fn_col=fn_col, label_col=label_col, suffix=suffix, **kwargs)

    @classmethod
    def from_lists(cls, path:PathOrStr, fnames:FilePathList, labels:Collection[str], valid_pct:float=0.2, seed:int=None,
                   item_cls:Callable=None, **kwargs):
        "Create from list of `fnames` in `path`."
        item_cls = ifnone(item_cls, ImageList)
        # each filename is labelled through a lookup in this dict
        fname2label = dict(zip(fnames, labels))
        split = item_cls(fnames, path=path).split_by_rand_pct(valid_pct, seed)
        labelled = split.label_from_func(lambda x:fname2label[x])
        return cls.create_from_ll(labelled, **kwargs)

    @classmethod
    def from_name_func(cls, path:PathOrStr, fnames:FilePathList, label_func:Callable, valid_pct:float=0.2, seed:int=None,
                       **kwargs):
        "Create from list of `fnames` in `path` with `label_func`."
        split = ImageList(fnames, path=path).split_by_rand_pct(valid_pct, seed)
        labelled = split.label_from_func(label_func)
        return cls.create_from_ll(labelled, **kwargs)

    @classmethod
    def from_name_re(cls, path:PathOrStr, fnames:FilePathList, pat:str, valid_pct:float=0.2, **kwargs):
        "Create from list of `fnames` in `path` with re expression `pat`."
        pat = re.compile(pat)
        def _get_label(fn):
            # the pattern is searched in the posix form of a `Path`
            if isinstance(fn, Path): fn = fn.as_posix()
            res = pat.search(str(fn))
            assert res,f'Failed to find "{pat}" in "{fn}"'
            return res.group(1)
        return cls.from_name_func(path, fnames, _get_label, valid_pct=valid_pct, **kwargs)

    @staticmethod
    def single_from_classes(path:Union[Path, str], classes:Collection[str], ds_tfms:TfmList=None, **kwargs):
        "Create an empty `ImageDataBunch` in `path` with `classes`. Typically used for inference."
        warn("This method is deprecated and will be removed in a future version, use `load_learner` after\n`Learner.export()`", DeprecationWarning)
        empty = ImageList([], path=path, ignore_empty=True).split_none()
        labelled = empty.label_const(0, label_cls=CategoryList, classes=classes)
        return labelled.transform(ds_tfms, **kwargs).databunch()

    def batch_stats(self, funcs:Collection[Callable]=None, ds_type:DatasetType=DatasetType.Train)->Tensor:
        "Grab a batch of data and call reduction function `func` per channel"
        funcs = ifnone(funcs, [torch.mean,torch.std])
        x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()
        per_channel = channel_view(x)
        return [func(per_channel, 1) for func in funcs]

    def normalize(self, stats:Collection[Tensor]=None, do_x:bool=True, do_y:bool=False)->None:
        "Add normalize transform using `stats` (defaults to `DataBunch.batch_stats`)"
        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')
        self.stats = self.batch_stats() if stats is None else stats
        self.norm,self.denorm = normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)
        self.add_tfm(self.norm)
        return self
def download_image(url,dest, timeout=4):
    "Download `url` to `dest` (overwriting it, no progress bar); an error is printed, not raised."
    try:
        download_url(url, dest, overwrite=True, show_progress=False, timeout=timeout)
    except Exception as e:
        print(f"Error {url} {e}")
def _download_image_inner(dest, url, i, timeout=4):
    "Download `url` to `dest` under the name `i` zero-padded to 8 digits, with the url's extension ('.jpg' when none is found)."
    found = re.findall(r'\.\w+?(?=(?:\?|$))', url)
    suffix = found[0] if found else '.jpg'
    fname = f"{i:08d}{suffix}"
    download_image(url, dest/fname, timeout=timeout)
def download_images(urls:Collection[str], dest:PathOrStr, max_pics:int=1000, max_workers:int=8, timeout=4):
    "Download images listed in text file `urls` to path `dest`, at most `max_pics`"
    # `urls` is the path of a text file with one url per line; `with` closes it once read
    with open(urls) as f: urls = f.read().strip().split("\n")[:max_pics]
    dest = Path(dest)
    dest.mkdir(exist_ok=True)
    parallel(partial(_download_image_inner, dest, timeout=timeout), urls, max_workers=max_workers)
def resize_to(img, targ_sz:int, use_min:bool=False):
    "Size to resize to, to hit `targ_sz` at same aspect ratio, in PIL coords (i.e w*h)"
    w,h = img.size
    # the side that must end up at `targ_sz`: the longer one, or the shorter one if `use_min`
    ref_sz = min(w,h) if use_min else max(w,h)
    ratio = targ_sz/ref_sz
    return int(w*ratio),int(h*ratio)
# Worker for `verify_images`. In this body: the file is opened once with warnings turned
# into errors, then opened again; when it is larger than `max_size` or does not have
# `n_channels` channels, a resized/converted copy is saved in `dest`. Any exception is
# printed and, if `delete`, the file is removed.
# NOTE(review): `idx` is not used in the body.
# NOTE(review): `max_size` is annotated as int or tuple, but it is compared with `>`
# against `img.height`/`img.width` -- confirm that tuples are really supported.
def verify_image(file:Path, idx:int, delete:bool, max_size:Union[int,Tuple[int,int]]=None, dest:Path=None, n_channels:int=3,
interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None, resume:bool=False, **kwargs):
"Check if the image in `file` exists, maybe resize it and copy it in `dest`."
try:
# deal with partially broken images as indicated by PIL warnings
with warnings.catch_warnings():
# every warning raised while opening the file becomes an exception
warnings.filterwarnings('error')
try:
with open(file, 'rb') as img_file: PIL.Image.open(img_file)
except Warning as w:
if "Possibly corrupt EXIF data" in str(w):
if delete: # green light to modify files
print(f"{file}: Removing corrupt EXIF data")
warnings.simplefilter("ignore")
# save EXIF-cleaned up image, which happens automatically
PIL.Image.open(file).save(file)
else: # keep user's files intact
print(f"{file}: Not removing corrupt EXIF data, pass `delete=True` to do that")
else: warnings.warn(w)
# NOTE(review): this image is never closed explicitly -- check whether the handle can
# still be open when `file.unlink()` runs in the `except` below.
img = PIL.Image.open(file)
imgarr = np.array(img)
# a 2d array means a single channel, otherwise the channels are the third axis
img_channels = 1 if len(imgarr.shape) == 2 else imgarr.shape[2]
if (max_size is not None and (img.height > max_size or img.width > max_size)) or img_channels != n_channels:
assert isinstance(dest, Path), "You should provide `dest` Path to save resized image"
dest_fname = dest/file.name
if ext is not None: dest_fname=dest_fname.with_suffix(ext)
# with `resume`, a file that already exists in `dest` is not processed again
if resume and os.path.isfile(dest_fname): return
if max_size is not None:
# NOTE(review): this also runs when only the channel count differs, so an image
# smaller than `max_size` is resized too -- confirm that this is intended.
new_sz = resize_to(img, max_size)
img = img.resize(new_sz, resample=interp)
if n_channels == 3: img = img.convert("RGB")
img.save(dest_fname, img_format, **kwargs)
except Exception as e:
print(f'{e}')
if delete: file.unlink()
def verify_images(path:PathOrStr, delete:bool=True, max_workers:int=4, max_size:Union[int]=None, recurse:bool=False,
                  dest:PathOrStr='.', n_channels:int=3, interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None,
                  resume:bool=None, **kwargs):
    "Check if the images in `path` aren't broken, maybe resize them and copy it in `dest`."
    path = Path(path)
    # no resuming by default when `dest` is `path` itself
    if resume is None and dest == '.': resume=False
    dest = path/Path(dest)
    os.makedirs(dest, exist_ok=True)
    files = get_image_files(path, recurse=recurse)
    check_kwargs = dict(delete=delete, max_size=max_size, dest=dest, n_channels=n_channels, interp=interp,
                        ext=ext, img_format=img_format, resume=resume)
    check = partial(verify_image, **check_kwargs, **kwargs)
    parallel(check, files, max_workers=max_workers)
class ImageList(ItemList):
    "`ItemList` suitable for computer vision."
    _bunch,_square_show,_square_show_res = ImageDataBunch,True,True

    def __init__(self, *args, convert_mode='RGB', after_open:Callable=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.convert_mode = convert_mode
        self.after_open = after_open
        # Register both options in `copy_new`.
        self.copy_new += ['convert_mode', 'after_open']
        self.c = 3
        # Cache of image sizes, filled by `get` and keyed by item index.
        self.sizes = {}

    def open(self, fn):
        "Open image in `fn` with `open_image`; subclass and override for custom behavior."
        return open_image(fn, convert_mode=self.convert_mode, after_open=self.after_open)

    def get(self, i):
        "Open item `i` and record its size in `self.sizes`."
        img = self.open(super().get(i))
        self.sizes[i] = img.size
        return img

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=None, **kwargs)->ItemList:
        "Get the files in `path` whose suffix is in `extensions` (defaults to `image_extensions`)."
        return super().from_folder(path=path, extensions=ifnone(extensions, image_extensions), **kwargs)

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='', **kwargs)->'ItemList':
        "Get the filenames in `cols` of `df`, prefixed by the list path (and `folder`) and followed by `suffix`."
        il = super().from_df(df, path=path, cols=cols, **kwargs)
        prefix = f'{il.path}{os.path.sep}'
        if folder is not None: prefix = prefix + f'{folder}{os.path.sep}'
        # Elementwise string concatenation: prefix + item + suffix.
        with_prefix = np.char.add(prefix, il.items.astype(str))
        il.items = np.char.add(with_prefix, suffix or '')
        return il

    @classmethod
    def from_csv(cls, path:PathOrStr, csv_name:str, header:str='infer', delimiter:str=None, **kwargs)->'ItemList':
        "Read `path/csv_name` with pandas and build the list through `from_df`."
        path = Path(path)
        frame = pd.read_csv(path/csv_name, header=header, delimiter=delimiter)
        return cls.from_df(frame, path=path, **kwargs)

    def reconstruct(self, t:Tensor):
        "Wrap `t`, cast to float and clamped to [0,1], in an `Image`."
        return Image(t.float().clamp(min=0,max=1))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a square grid of axes."
        n_side = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(n_side, n_side, imgsize=imgsize, figsize=figsize)
        for x,y,ax in zip(xs, ys, axs.flatten()):
            x.show(ax=ax, y=y, **kwargs)
        # Switch off the axes of the grid that received no image.
        for spare in axs.flatten()[len(xs):]:
            spare.axis('off')
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`."
        if not self._square_show_res:
            # One row per sample: input with target on the left, input with prediction on the right.
            title = 'Ground truth/Predictions'
            axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
            for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
                x.show(ax=axs[i,0], y=y, **kwargs)
                x.show(ax=axs[i,1], y=z, **kwargs)
            return
        # Square grid: target and prediction are written in the title of each image.
        title = 'Ground truth\nPredictions'
        n_side = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(n_side, n_side, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=12)
        for x,y,z,ax in zip(xs,ys,zs,axs.flatten()):
            x.show(ax=ax, title=f'{str(y)}\n{str(z)}', **kwargs)
        for spare in axs.flatten()[len(xs):]:
            spare.axis('off')
class ObjectCategoryProcessor(MultiCategoryProcessor):
    "`PreProcessor` for labelled bounding boxes."
    def __init__(self, ds:ItemList, pad_idx:int=0):
        super().__init__(ds)
        self.pad_idx = pad_idx
        self.state_attrs.append('pad_idx')

    def process(self, ds:ItemList):
        "Copy `pad_idx` onto `ds`, then run the parent processing."
        ds.pad_idx = self.pad_idx
        super().process(ds)

    def process_one(self,item):
        "Keep `item[0]` as is and map every label of `item[1]` through `self.c2i` (`None` when unknown)."
        encoded = [self.c2i.get(o,None) for o in item[1]]
        return [item[0], encoded]

    def generate_classes(self, items):
        "Generate classes from the labels (`o[1]`) of `items` and put `background` first."
        labels = [o[1] for o in items]
        found = super().generate_classes(labels)
        return ['background'] + list(found)
def _get_size(xs,i):
size = xs.sizes.get(i,None)
if size is None:
# Image hasn't been accessed yet, so we don't know its size
_ = xs[i]
size = xs.sizes[i]
return size
class ObjectCategoryList(MultiCategoryList):
    "`ItemList` for labelled bounding boxes."
    _processor = ObjectCategoryProcessor

    def get(self, i):
        "Build the `ImageBBox` of item `i` from the size of the matching input and `self.items[i]`."
        size = _get_size(self.x,i)
        return ImageBBox.create(*size, *self.items[i], classes=self.classes, pad_idx=self.pad_idx)

    def analyze_pred(self, pred):
        "Predictions are returned untouched."
        return pred

    def reconstruct(self, t, x):
        "Rebuild an `ImageBBox` from `t = (bboxes, labels)`, or return `None` if every label equals `pad_idx`."
        bboxes, labels = t
        non_pad = (labels - self.pad_idx).nonzero()
        if len(non_pad) == 0: return
        # Drop everything before the first label that differs from `pad_idx`.
        first = non_pad.min()
        return ImageBBox.create(*x.size, bboxes[first:], labels=labels[first:], classes=self.classes, scale=False)
class ObjectItemList(ImageList):
    "`ItemList` suitable for object detection."
    # Labels are handled by `ObjectCategoryList`; results are shown in two columns, not on a square grid.
    _label_cls = ObjectCategoryList
    _square_show_res = False
class SegmentationProcessor(PreProcessor):
    "`PreProcessor` that stores the classes for segmentation."
    def __init__(self, ds:ItemList):
        # Remember the classes of the dataset the processor is created from.
        self.classes = ds.classes

    def process(self, ds:ItemList):
        "Give `ds` the stored classes and set `ds.c` to their number."
        ds.classes,ds.c = self.classes,len(self.classes)
class SegmentationLabelList(ImageList):
    "`ItemList` for segmentation masks."
    _processor=SegmentationProcessor

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.copy_new.append('classes')
        # Build the loss first, then store both attributes (same order as a tuple assignment).
        flat_loss = CrossEntropyFlat(axis=1)
        self.classes = classes
        self.loss_func = flat_loss

    def open(self, fn):
        "Masks are opened with `open_mask`."
        return open_mask(fn)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Take the argmax over dim 0 and add a leading unit dimension; `thresh` is not used."
        return pred.argmax(dim=0)[None]

    def reconstruct(self, t:Tensor):
        "Wrap `t` in an `ImageSegment`."
        return ImageSegment(t)
class SegmentationItemList(ImageList):
    "`ItemList` suitable for segmentation tasks."
    # Labels are handled by `SegmentationLabelList`; results are shown in two columns, not on a square grid.
    _label_cls = SegmentationLabelList
    _square_show_res = False
class PointsProcessor(PreProcessor):
    "`PreProcessor` that stores the number of targets for point regression."
    def __init__(self, ds:ItemList):
        # Number of values in the first item once flattened.
        first_flat = ds.items[0].reshape(-1)
        self.c = len(first_flat)

    def process(self, ds:ItemList):
        "Copy the stored target count onto `ds`."
        ds.c = self.c
class PointsLabelList(ItemList):
    "`ItemList` for points."
    _processor = PointsProcessor

    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        self.loss_func = MSELossFlat()

    def get(self, i):
        "Return item `i` as scaled `ImagePoints`, using the size of the matching input."
        pts = super().get(i)
        field = FlowField(_get_size(self.x,i), pts)
        return ImagePoints(field, scale=True)

    def analyze_pred(self, pred, thresh:float=0.5):
        "View the prediction as rows of two values; `thresh` is not used."
        return pred.view(-1,2)

    def reconstruct(self, t, x):
        "Rebuild unscaled `ImagePoints` from `t` with the size of `x`."
        return ImagePoints(FlowField(x.size, t), scale=False)
class PointsItemList(ImageList):
    "`ItemList` for `Image` to `ImagePoints` tasks."
    # Labels are handled by `PointsLabelList`; results are shown in two columns, not on a square grid.
    _label_cls = PointsLabelList
    _square_show_res = False
class ImageImageList(ImageList):
    "`ItemList` suitable for `Image` to `Image` tasks."
    _label_cls,_square_show,_square_show_res = ImageList,False,False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show each input of `xs` next to its target from `ys`, one pair per row."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for row,(inp,targ) in enumerate(zip(xs,ys)):
            inp.show(ax=axs[row,0], **kwargs)
            targ.show(ax=axs[row,1], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show `xs` (inputs), `zs` (predictions) and `ys` (targets) in three columns, in that order."
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for row,(inp,targ,pred) in enumerate(zip(xs,ys,zs)):
            # Drawing order is input, target (last column), prediction (middle column).
            inp.show(ax=axs[row,0], **kwargs)
            targ.show(ax=axs[row,2], **kwargs)
            pred.show(ax=axs[row,1], **kwargs)
def _ll_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):
    "Set the composed `train_tfm`/`valid_tfm` as `after_open` of the train/valid inputs and return `self`."
    train_x, valid_x = self.train.x, self.valid.x
    train_x.after_open = compose(train_tfm)
    valid_x.after_open = compose(valid_tfm)
    return self
def _db_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):
    "Set the composed `train_tfm`/`valid_tfm` as `after_open` of `train_ds.x`/`valid_ds.x` and return `self`."
    train_x, valid_x = self.train_ds.x, self.valid_ds.x
    train_x.after_open = compose(train_tfm)
    valid_x.after_open = compose(valid_tfm)
    return self
def _presize(self, size:int, val_xtra_size:int=32, scale:Tuple[float]=(0.08, 1.0), ratio:Tuple[float]=(0.75, 4./3.),
             interpolation:int=2):
    "Call `pre_transform` with a `RandomResizedCrop` to `size` for training and resize + center crop for validation."
    train_tfm = tvt.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)
    # Validation images are resized to `size+val_xtra_size`, then center-cropped to `size`.
    valid_tfms = [tvt.Resize(size+val_xtra_size), tvt.CenterCrop(size)]
    return self.pre_transform(train_tfm, valid_tfms)
# Attach the helpers defined above as methods, so that both `LabelLists` and `DataBunch`
# expose `pre_transform` and `presize` (the same `_presize` function serves both classes).
LabelLists.pre_transform = _ll_pre_transform
DataBunch.pre_transform = _db_pre_transform
LabelLists.presize = _presize
DataBunch.presize = _presize
================================================
FILE: fastai/vision/gan.py
================================================
from ..torch_core import *
from ..layers import *
from ..callback import *
from ..basic_data import *
from ..basic_train import Learner, LearnerCallback
from .image import Image
from .data import ImageList
# Names exported by a star-import of this module.
__all__ = ['basic_critic', 'basic_generator', 'GANModule', 'GANLoss', 'GANTrainer', 'FixedGANSwitcher', 'AdaptiveGANSwitcher',
'GANLearner', 'NoisyItem', 'GANItemList', 'gan_critic', 'AdaptiveLoss', 'accuracy_thresh_expand',
'GANDiscriminativeLR']
def AvgFlatten():
    "Return a `Lambda` layer that averages its input over dim 0 and views the result as a 1-element tensor."
    def _mean_dim0(x): return x.mean(0).view(1)
    return Lambda(_mean_dim0)
def basic_critic(in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, **conv_kwargs):
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    # Stem: stride-2 conv without normalization.
    stem = conv_layer(n_channels, n_features, 4, 2, 1, leaky=0.2, norm_type=None, **conv_kwargs)
    size, ftrs = in_size//2, n_features
    # Optional stride-1 layers that keep the feature count, grouped in one `nn.Sequential`.
    extras = [conv_layer(ftrs, ftrs, 3, 1, leaky=0.2, **conv_kwargs) for _ in range(n_extra_layers)]
    layers = [stem, nn.Sequential(*extras)]
    # Halve the tracked size and double the features until the size is at most 4.
    while size > 4:
        layers.append(conv_layer(ftrs, ftrs*2, 4, 2, 1, leaky=0.2, **conv_kwargs))
        ftrs *= 2
        size //= 2
    layers.append(conv2d(ftrs, 1, 4, padding=0))
    layers.append(AvgFlatten())
    return nn.Sequential(*layers)
def basic_generator(in_size:int, n_channels:int, noise_sz:int=100, n_features:int=64, n_extra_layers=0, **conv_kwargs):
    "A basic generator from `noise_sz` to images `n_channels` x `in_size` x `in_size`."
    # Feature count of the first layer: `n_features//2` doubled once per doubling from 4 up to `in_size`.
    size, ftrs = 4, n_features//2
    while size < in_size:
        size *= 2
        ftrs *= 2
    layers = [conv_layer(noise_sz, ftrs, 4, 1, transpose=True, **conv_kwargs)]
    # Stride-2 transposed convs, halving the features while the tracked size is below `in_size // 2`.
    size = 4
    while size < in_size // 2:
        layers.append(conv_layer(ftrs, ftrs//2, 4, 2, 1, transpose=True, **conv_kwargs))
        ftrs //= 2
        size *= 2
    for _ in range(n_extra_layers):
        layers.append(conv_layer(ftrs, ftrs, 3, 1, 1, transpose=True, **conv_kwargs))
    # Last stride-2 transposed conv to `n_channels`, followed by tanh.
    layers.append(conv2d_trans(ftrs, n_channels, 4, 2, 1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)
class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self, generator:nn.Module=None, critic:nn.Module=None, gen_mode:bool=False):
        self.gen_mode = gen_mode
        # Both models are only stored when a (truthy) `generator` is passed.
        if generator:
            self.generator = generator
            self.critic = critic

    def forward(self, *args):
        "Send `args` to the generator in generator mode, to the critic otherwise."
        if self.gen_mode: return self.generator(*args)
        return self.critic(*args)

    def switch(self, gen_mode:bool=None):
        "Put the model in generator mode if `gen_mode`, in critic mode otherwise; toggle when `gen_mode` is `None`."
        if gen_mode is None: gen_mode = not self.gen_mode
        self.gen_mode = gen_mode
class GANLoss(GANModule):
    "Wrapper around `loss_funcC` (for the critic) and `loss_funcG` (for the generator)."
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule):
        super().__init__()
        self.loss_funcG = loss_funcG
        self.loss_funcC = loss_funcC
        self.gan_model = gan_model

    def generator(self, output, target):
        "Score `output` with the critic, then call `self.loss_funcG(fake_pred, target, output)`."
        critic_score = self.gan_model.critic(output)
        return self.loss_funcG(critic_score, target, output)

    def critic(self, real_pred, input):
        "Generate fakes from `input`, score them with the critic and compare with `real_pred` in `self.loss_funcC`."
        # `requires_grad_(False)` is applied to the input, `requires_grad_(True)` to the generated batch.
        noise = input.requires_grad_(False)
        fake = self.gan_model.generator(noise).requires_grad_(True)
        fake_pred = self.gan_model.critic(fake)
        return self.loss_funcC(real_pred, fake_pred)
class GANTrainer(LearnerCallback):
    "Handles GAN Training."
    _order=-20
    def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False,
                 show_img:bool=True):
        # - `switch_eval`: also toggle train/eval mode of the two models on each switch.
        # - `clip`: if not None, critic weights are clamped to [-clip, clip] at every batch.
        # - `beta`: passed to the `SmoothenValue` trackers of both losses.
        # - `gen_first`: mode selected in `on_train_begin`.
        # - `show_img`: show a generated sample through the progress bar at the end of each epoch.
        super().__init__(learn)
        self.switch_eval,self.clip,self.beta,self.gen_first,self.show_img = switch_eval,clip,beta,gen_first,show_img
        self.generator,self.critic = self.model.generator,self.model.critic

    def _set_trainable(self):
        "Enable gradients on the model trained in the current mode and disable them on the other one."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        requires_grad(train_model, True)
        requires_grad(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def on_train_begin(self, **kwargs):
        "Create the optimizers for the generator and critic if necessary, initialize smootheners."
        if not getattr(self,'opt_gen',None):
            self.opt_gen = self.opt.new([nn.Sequential(*flatten_model(self.generator))])
        else: self.opt_gen.lr,self.opt_gen.wd = self.opt.lr,self.opt.wd
        if not getattr(self,'opt_critic',None):
            self.opt_critic = self.opt.new([nn.Sequential(*flatten_model(self.critic))])
        else: self.opt_critic.lr,self.opt_critic.wd = self.opt.lr,self.opt.wd
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.closses,self.glosses = [],[]
        self.smoothenerG,self.smoothenerC = SmoothenValue(self.beta),SmoothenValue(self.beta)
        #self.recorder.no_val=True
        # These two names are matched by the two values returned from `on_epoch_end`.
        self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        self.imgs,self.titles = [],[]

    def on_train_end(self, **kwargs):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Clamp the weights with `self.clip` if it's not None, return the correct input."
        if self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        # In critic mode input and target are swapped.
        return {'last_input':last_input,'last_target':last_target} if self.gen_mode else {'last_input':last_target,'last_target':last_input}

    def on_backward_begin(self, last_loss, last_output, **kwargs):
        "Record `last_loss` in the proper list."
        last_loss = last_loss.detach().cpu()
        if self.gen_mode:
            self.smoothenerG.add_value(last_loss)
            self.glosses.append(self.smoothenerG.smooth)
            # Keep the last generated batch for `on_epoch_end`.
            self.last_gen = last_output.detach().cpu()
        else:
            self.smoothenerC.add_value(last_loss)
            self.closses.append(self.smoothenerC.smooth)

    def on_epoch_begin(self, epoch, **kwargs):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    def on_epoch_end(self, pbar, epoch, last_metrics, **kwargs):
        "Show a sample image if requested and put the smoothed losses in the recorder."
        # Only the image display is conditional: the two metrics registered in `on_train_begin`
        # are always reported so the recorder gets one value per metric name.
        if self.show_img and hasattr(self, 'last_gen'):
            data = self.learn.data
            img = self.last_gen[0]
            norm = getattr(data,'norm',False)
            if norm and norm.keywords.get('do_y',False): img = data.denorm(img)
            img = data.train_ds.y.reconstruct(img)
            self.imgs.append(img)
            self.titles.append(f'Epoch {epoch}')
            pbar.show_imgs(self.imgs, self.titles)
        # `smooth` may not exist yet on a smoothener that received no value, hence the `None` default.
        return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode:bool=None):
        "Switch the model, if `gen_mode` is provided, in the desired mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        # Use the optimizer that belongs to the model now being trained.
        self.opt.opt = self.opt_gen.opt if self.gen_mode else self.opt_critic.opt
        self._set_trainable()
        self.model.switch(gen_mode)
        self.loss_func.switch(gen_mode)
class FixedGANSwitcher(LearnerCallback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    def __init__(self, learn:Learner, n_crit:Union[int,Callable]=1, n_gen:Union[int,Callable]=1):
        super().__init__(learn)
        self.n_crit = n_crit
        self.n_gen = n_gen

    def on_train_begin(self, **kwargs):
        "Reset both iteration counters."
        self.n_c = 0
        self.n_g = 0

    def on_batch_end(self, iteration, **kwargs):
        "Count the batch for the current mode and switch once its quota is reached."
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            quota, other_count, done = self.n_gen, self.n_c, self.n_g
        else:
            self.n_c += 1
            quota, other_count, done = self.n_crit, self.n_g, self.n_c
        # A non-int quota is called with the count of the other mode to get the target.
        if not isinstance(quota, int): quota = quota(other_count)
        if quota == done:
            self.learn.gan_trainer.switch()
            self.n_c = 0
            self.n_g = 0
# NOTE: intentionally not a `@dataclass`. The class declares no fields and defines its own `__init__`,
# so the decorator only generated an `__eq__` under which all instances compare equal, set
# `__hash__` to `None` and replaced the inherited `__repr__`.
class AdaptiveGANSwitcher(LearnerCallback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`critic_thresh`."
    def __init__(self, learn:Learner, gen_thresh:float=None, critic_thresh:float=None):
        super().__init__(learn)
        self.gen_thresh,self.critic_thresh = gen_thresh,critic_thresh

    def on_batch_end(self, last_loss, **kwargs):
        "Switch the model if the threshold of the current mode is `None` or `last_loss` is below it."
        thresh = self.gen_thresh if self.gan_trainer.gen_mode else self.critic_thresh
        if thresh is None or last_loss < thresh: self.gan_trainer.switch()
def gan_loss_from_func(loss_gen, loss_crit, weights_gen:Tuple[float,float]=None):
    "Build the pair `(generator loss, critic loss)` for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        # Weighted sum of the critic loss on the fakes (against ones) and `loss_gen(output, target)`.
        all_ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        crit_term = weights_gen[0] * loss_crit(fake_pred, all_ones)
        gen_term = weights_gen[1] * loss_gen(output, target)
        return crit_term + gen_term

    def _loss_C(real_pred, fake_pred):
        # Mean of the critic loss on reals (against ones) and on fakes (against zeros).
        all_ones = real_pred.new_ones (real_pred.shape[0])
        all_zeros = fake_pred.new_zeros(fake_pred.shape[0])
        real_term = loss_crit(real_pred, all_ones)
        fake_term = loss_crit(fake_pred, all_zeros)
        return (real_term + fake_term) / 2

    return _loss_G, _loss_C
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,
                 crit_loss_func:LossFunction, switcher:Callback=None, gen_first:bool=False, switch_eval:bool=True,
                 show_img:bool=True, clip:float=None, **learn_kwargs):
        # - `generator`/`critic` are wrapped in a `GANModule`, the two loss functions in a `GANLoss`.
        # - `switcher` defaults to a `FixedGANSwitcher` doing 5 critic iterations for 1 generator iteration.
        # - `gen_first`, `switch_eval`, `show_img` and `clip` are handed to the `GANTrainer` callback.
        # - `learn_kwargs` go to `Learner.__init__`.
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        switcher = ifnone(switcher, partial(FixedGANSwitcher, n_crit=5, n_gen=1))
        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)
        # `gen_first` must be forwarded explicitly: it is a named parameter here, so it never reaches
        # `learn_kwargs` and would otherwise be silently dropped.
        trainer = GANTrainer(self, clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        self.gan_trainer = trainer
        self.callbacks.append(trainer)

    @classmethod
    def from_learners(cls, learn_gen:Learner, learn_crit:Learner, switcher:Callback=None,
                      weights_gen:Tuple[float,float]=None, **learn_kwargs):
        "Create a GAN from `learn_gen` and `learn_crit`."
        # Data and generator come from `learn_gen`, the critic from `learn_crit`; the losses are
        # derived from the loss functions of both learners.
        losses = gan_loss_from_func(learn_gen.loss_func, learn_crit.loss_func, weights_gen=weights_gen)
        return cls(learn_gen.data, learn_gen.model, learn_crit.model, *losses, switcher=switcher, **learn_kwargs)

    @classmethod
    def wgan(cls, data:DataBunch, generator:nn.Module, critic:nn.Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs):
        "Create a WGAN from `data`, `generator` and `critic`."
        # Uses `NoopLoss` for the generator, `WassersteinLoss` for the critic and weight clipping at `clip`.
        return cls(data, generator, critic, NoopLoss(), WassersteinLoss(), switcher=switcher, clip=clip, **learn_kwargs)
class NoisyItem(ItemBase):
    "A random `ItemBase` of size `noise_sz`."
    def __init__(self, noise_sz):
        # `data` is a fresh `torch.randn` sample of shape (noise_sz, 1, 1); `obj` keeps the size itself.
        noise = torch.randn(noise_sz, 1, 1)
        self.obj = noise_sz
        self.data = noise

    def __str__(self):
        return ''

    def apply_tfms(self, tfms, **kwargs):
        "Transforms are ignored: the item is returned unchanged."
        return self
class GANItemList(ImageList):
    "`ItemList` suitable for GANs."
    _label_cls = ImageList

    def __init__(self, items, noise_sz:int=100, **kwargs):
        super().__init__(items, **kwargs)
        self.noise_sz = noise_sz
        self.copy_new.append('noise_sz')

    def get(self, i):
        "Every item is a new `NoisyItem` of size `noise_sz`; `i` is not used."
        return NoisyItem(self.noise_sz)

    def reconstruct(self, t):
        "Return a new `NoisyItem` whose size is the first dimension of `t`."
        return NoisyItem(t.size(0))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Shows `ys` (target images) on a figure of `figsize`."
        # Inputs and targets are swapped before delegating to the parent `show_xys`.
        shown, overlay = ys, xs
        super().show_xys(shown, overlay, imgsize=imgsize, figsize=figsize, **kwargs)

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Shows `zs` (generated images) on a figure of `figsize`."
        # The predictions take the place of the inputs in the parent `show_xys`.
        shown, overlay = zs, xs
        super().show_xys(shown, overlay, imgsize=imgsize, figsize=figsize, **kwargs)
# Keyword arguments shared by the conv layers and the residual block of `gan_critic`.
_conv_args = {'leaky': 0.2, 'norm_type': NormType.Spectral}

def _conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
    "`conv_layer` from `ni` to `nf` channels using `_conv_args` plus any extra `kwargs`."
    # Both dicts are unpacked separately, so a key present in both raises a `TypeError`.
    return conv_layer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
def gan_critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:int=0.15):
    "Critic to train a `GAN`."
    # Stem: stride-2 conv, dropout at `p/2`, then a dense residual block.
    layers = [_conv(n_channels, nf, ks=4, stride=2)]
    layers.append(nn.Dropout2d(p/2))
    layers.append(res_block(nf, dense=True,**_conv_args))
    nf *= 2 # after dense block
    # Each block: dropout at `p`, then a stride-2 conv doubling the features (self-attention on the first only).
    for block_idx in range(n_blocks):
        layers.append(nn.Dropout2d(p))
        layers.append(_conv(nf, nf*2, ks=4, stride=2, self_attention=(block_idx==0)))
        nf *= 2
    # Head: conv to a single channel without activation, then flatten.
    layers.append(_conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False))
    layers.append(Flatten())
    return nn.Sequential(*layers)
class GANDiscriminativeLR(LearnerCallback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    def __init__(self, learn:Learner, mult_lr:float = 5.):
        super().__init__(learn)
        self.mult_lr = mult_lr

    def on_batch_begin(self, train, **kwargs):
        "Multiply the current lr by `mult_lr` for a training batch in critic mode."
        in_critic_mode = not self.learn.gan_trainer.gen_mode
        if in_critic_mode and train:
            self.learn.opt.lr *= self.mult_lr

    def on_step_end(self, **kwargs):
        "Divide the lr by `mult_lr` again when in critic mode."
        if self.learn.gan_trainer.gen_mode: return
        self.learn.opt.lr /= self.mult_lr
class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit):
        self.crit = crit

    def forward(self, output, target):
        # Add a dimension after the first one, expand to the shape of `output` and cast to float.
        expanded = target[:,None].expand_as(output).float()
        return self.crit(output, expanded)
def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    "Compute accuracy at `thresh` after expanding `y_true` to the size of `y_pred`."
    scores = y_pred.sigmoid() if sigmoid else y_pred
    above = scores > thresh
    # `y_true` gets a new dimension after the first one and is expanded to the shape of the predictions.
    expected = y_true[:,None].expand_as(scores).byte()
    return (above == expected).float().mean()
================================================
FILE: fastai/vision/image.py
================================================
"`Image` provides support to convert, transform and show images"
from ..torch_core import *
from ..basic_data import *
from ..layers import MSELossFlat
from io import BytesIO
import PIL
# Names exported by a star-import of this module.
__all__ = ['PIL', 'Image', 'ImageBBox', 'ImageSegment', 'ImagePoints', 'FlowField', 'RandTransform', 'TfmAffine', 'TfmCoord',
'TfmCrop', 'TfmLighting', 'TfmPixel', 'Transform', 'bb2hw', 'image2np', 'open_image', 'open_mask', 'tis2hw',
'pil2tensor', 'scale_flow', 'show_image', 'CoordFunc', 'TfmList', 'open_mask_rle', 'rle_encode',
'rle_decode', 'ResizeMethod', 'plot_flat', 'plot_multi', 'show_multi', 'show_all']
# Resize strategies; as an `IntEnum` built from a name string the members are numbered from 1
# (CROP=1, PAD=2, SQUISH=3, NO=4).
ResizeMethod = IntEnum('ResizeMethod', 'CROP PAD SQUISH NO')
def pil2tensor(image:Union[NPImage,NPArray],dtype:np.dtype)->TensorImage:
    "Convert PIL style `image` array to torch style image tensor."
    arr = np.asarray(image)
    # A 2-d array gets a trailing channel axis of size 1.
    if arr.ndim == 2: arr = arr[:, :, None]
    # Move the last axis first: (d0, d1, d2) -> (d2, d0, d1).
    arr = arr.transpose(2, 0, 1)
    return torch.from_numpy(arr.astype(dtype, copy=False))
def image2np(image:Tensor)->np.ndarray:
    "Convert from torch style `image` to numpy/matplotlib style."
    # Move the first axis last, on the CPU, before converting to numpy.
    arr = image.cpu().permute(1,2,0).numpy()
    # A single-channel result loses its trailing axis.
    if arr.shape[2] == 1: return arr[...,0]
    return arr
def bb2hw(a:Collection[int])->np.ndarray:
    "Return `[a[1], a[0], a[3]-a[1], a[2]-a[0]]` as an array."
    # NOTE(review): presumably `a` holds two corner points and the result is an origin followed by
    # two extents -- confirm the exact coordinate order against the callers.
    first, second = a[1], a[0]
    return np.array([first, second, a[3]-first, a[2]-second])
def tis2hw(size:Union[int,TensorImageSize]) -> Tuple[int,int]:
    "Convert `int` or `TensorImageSize` to (height,width) of an image."
    if type(size) is str: raise RuntimeError("Expected size to be an int or a tuple, got a string.")
    if isinstance(size, int):
        # A single int is used for both dimensions.
        return listify(size, 2)
    # Otherwise only the last two entries are kept.
    return listify(size[-2:],2)
def _draw_outline(o:Patch, lw:int):
    "Give `o` a black stroke of width `lw` underneath its normal rendering."
    stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([stroke, patheffects.Normal()])
def _draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):
    "Draw the outlined rectangle `b` on `ax`, with an optional `text` label at its origin."
    # `b[:2]` is the rectangle origin, `b[-2:]` are passed on as its two extents.
    rect = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))
    _draw_outline(rect, 4)
    if text is None: return
    label = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label,1)
def _get_default_args(func:Callable):
return {k: v.default
for k, v in inspect.signature(func).parameters.items()
if v.default is not inspect.Parameter.empty}
@dataclass
class FlowField():
    "Wrap together some coords `flow` with a `size`."
    # Size the coordinates refer to.
    size:Tuple[int,int]
    # Tensor holding the coordinates.
    flow:Tensor
# Type alias for the callables accepted by the `coord` methods of the image classes.
CoordFunc = Callable[[FlowField, ArgStar, KWArgs], LogitTensorImage]
class Image(ItemBase):
    "Support applying transforms to image data in `px`."
    def __init__(self, px:Tensor):
        self._px = px
        # Nothing is queued yet: no logit pixels, no flow field, no affine matrix.
        self._logit_px = self._flow = self._affine_mat = None
        self.sample_kwargs = {}

    def set_sample(self, **kwargs)->'ImageBase':
        "Store `kwargs`; `refresh` forwards them to `_grid_sample`."
        self.sample_kwargs = kwargs
        return self

    def clone(self):
        "Return a new object of the same class wrapping a clone of the pixel tensor."
        return self.__class__(self.px.clone())

    @property
    def shape(self)->Tuple[int,int,int]: return self._px.shape
    @property
    def size(self)->Tuple[int,int]: return self.shape[-2:]
    @property
    def device(self)->torch.device: return self._px.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.shape)}'
    def _repr_png_(self): return self._repr_image_format('png')
    def _repr_jpeg_(self): return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        "Encode the pixels with `plt.imsave` in `format_str` and return the raw bytes."
        with BytesIO() as buf:
            plt.imsave(buf, image2np(self.px), format=format_str)
            return buf.getvalue()

    def apply_tfms(self, tfms:TfmList, do_resolve:bool=True, xtra:Optional[Dict[Callable,dict]]=None,
                   size:Optional[Union[int,TensorImageSize]]=None, resize_method:ResizeMethod=None,
                   mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True,
                   is_x:bool=True, x_frames:int=1, y_frames:int=1)->TensorImage:
        "Apply all `tfms` to a clone of the `Image`; random args are resolved first if `do_resolve`."
        # Nothing to do: the very same object is returned.
        if not (tfms or xtra or size): return self
        if isinstance(size, int):
            # An int size covering several frames becomes a `(size, size*frames)` target.
            num_frames = x_frames if is_x else y_frames
            if num_frames > 1: size = (size, size*num_frames)
        tfms = listify(tfms)
        xtra = ifnone(xtra, {})
        # Default resize strategy: squish for a listy size, crop otherwise.
        default_rsz = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP
        resize_method = ifnone(resize_method, default_rsz)
        if resize_method <= 2 and size is not None: tfms = self._maybe_add_crop_pad(tfms)
        tfms = sorted(tfms, key=lambda o: o.tfm.order)
        if do_resolve: _resolve_tfms(tfms)
        x = self.clone()
        x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)
        crop_or_pad = resize_method in (ResizeMethod.CROP,ResizeMethod.PAD)
        if size is None:
            size = x.size
        else:
            crop_target = _get_crop_target(size, mult=mult)
            if crop_or_pad:
                target = _get_resize_target(x, crop_target, do_crop=(resize_method==ResizeMethod.CROP))
                x.resize(target)
            elif resize_method==ResizeMethod.SQUISH:
                x.resize((x.shape[0],) + crop_target)
        # Transforms wrapping a `TfmCrop` only run for the CROP/PAD methods, with the crop target as size.
        size_tfms = [o for o in tfms if isinstance(o.tfm,TfmCrop)]
        for tfm in tfms:
            if tfm.tfm in xtra:
                x = tfm(x, **xtra[tfm.tfm])
            elif tfm in size_tfms:
                if crop_or_pad:
                    x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)
            else:
                x = tfm(x)
        return x.refresh()

    def refresh(self)->None:
        "Apply any logit, flow, or affine transfers that have been sent to the `Image`."
        if self._logit_px is not None:
            # Pending logit pixels go back to pixel space through an in-place sigmoid.
            self._px = self._logit_px.sigmoid_()
            self._logit_px = None
        has_pending_geometry = self._affine_mat is not None or self._flow is not None
        if has_pending_geometry:
            self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)
            self.sample_kwargs = {}
            self._flow = None
        return self

    def save(self, fn:PathOrStr):
        "Save the image to `fn`."
        # Pixels are multiplied by 255 and cast to uint8 before being handed to PIL.
        arr = image2np(self.data*255).astype(np.uint8)
        PIL.Image.fromarray(arr).save(fn)

    @property
    def px(self)->TensorImage:
        "Get the tensor pixel buffer, after applying anything still pending."
        self.refresh()
        return self._px
    @px.setter
    def px(self,v:TensorImage)->None:
        "Set the pixel buffer to `v`."
        self._px=v

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine transforms."
        if self._flow is None:
            self._flow = _affine_grid(self.shape)
        if self._affine_mat is not None:
            # Fold the queued affine matrix into the grid, then clear it.
            self._flow = _affine_mult(self._flow,self._affine_mat)
            self._affine_mat = None
        return self._flow
    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any):
        "Equivalent to `image = sigmoid(func(logit(image)))`."
        self.logit_px = func(self.logit_px, *args, **kwargs)
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.px = func(image.px)`."
        self.px = func(self.px, *args, **kwargs)
        return self

    def coord(self, func:CoordFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.flow = func(image.flow, *args, **kwargs)`."
        self.flow = func(self.flow, *args, **kwargs)
        return self

    def affine(self, func:AffineFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.affine_mat = image.affine_mat @ func()`."
        mat = tensor(func(*args, **kwargs)).to(self.device)
        self.affine_mat = self.affine_mat @ mat
        return self

    def resize(self, size:Union[int,TensorImageSize])->'Image':
        "Resize the image to `size`, size can be a single int."
        assert self._flow is None
        # An int means a square target that keeps the current first dimension.
        if isinstance(size, int): size=(self.shape[0], size, size)
        if tuple(size)==tuple(self.shape): return self
        self.flow = _affine_grid(size)
        return self

    @property
    def affine_mat(self)->AffineMatrix:
        "Get the affine matrix that will be applied by `refresh`."
        if self._affine_mat is None:
            # Start from a 3x3 identity on the device of the image.
            self._affine_mat = torch.eye(3).to(self.device)
        return self._affine_mat
    @affine_mat.setter
    def affine_mat(self,v)->None: self._affine_mat=v

    @property
    def logit_px(self)->LogitTensorImage:
        "Get logit(image.px)."
        if self._logit_px is None: self._logit_px = logit_(self.px)
        return self._logit_px
    @logit_px.setter
    def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v

    @property
    def data(self)->TensorImage:
        "Return this images pixels as a tensor."
        return self.px

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
             cmap:str=None, y:Any=None, **kwargs):
        "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
        cmap = ifnone(cmap, defaults.cmap)
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
        # `kwargs` are only used when drawing `y`.
        if y is not None: y.show(ax=ax, **kwargs)
        if title is not None: ax.set_title(title)
class ImageSegment(Image):
    "Support applying transforms to segmentation masks data in `px`."
    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'Image':
        "Lighting transforms are ignored: the mask is returned unchanged."
        return self

    def refresh(self):
        "Force the `nearest` sampling mode, then refresh like an `Image`."
        self.sample_kwargs['mode'] = 'nearest'
        return super().refresh()

    @property
    def data(self)->TensorImage:
        "Return this image pixels as a `LongTensor`."
        return self.px.long()

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
             cmap:str='tab20', alpha:float=0.5, **kwargs):
        "Show the `ImageSegment` on `ax`."
        # Drawn with nearest interpolation, `vmin=0` and transparency `alpha`.
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize,
                        interpolation='nearest', alpha=alpha, vmin=0, **kwargs)
        if title: ax.set_title(title)

    def reconstruct(self, t:Tensor):
        "Wrap `t` in an `ImageSegment`."
        return ImageSegment(t)
class ImagePoints(Image):
    "Support applying transforms to a `flow` of points."

    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True):
        "Keep `flow`, rescaled by `scale_flow` if `scale` and with its two coords flipped if `y_first`."
        if scale:
            flow = scale_flow(flow)
        if y_first:
            flow.flow = flow.flow.flip(1)
        self._flow, self._affine_mat = flow, None
        # Transforms are queued here and only applied when `flow` is read.
        self.flow_func = []
        self.sample_kwargs = {}
        self.transformed = False
        self.loss_func = MSELossFlat()

    def clone(self):
        "Mimic the behavior of torch.clone for `ImagePoints` objects."
        sz = self.size
        pts = self.flow.flow.clone()
        return self.__class__(FlowField(sz, pts), scale=False, y_first=False)

    @property
    def shape(self)->Tuple[int,int,int]:
        "`(1, *size)` of the underlying flow."
        return (1, *self._flow.size)

    @property
    def size(self)->Tuple[int,int]:
        "Size stored on the underlying flow."
        return self._flow.size

    @size.setter
    def size(self, sz:int):
        self._flow.size = sz

    @property
    def device(self)->torch.device:
        "Device of the tensor holding the points."
        return self._flow.flow.device

    def __repr__(self):
        return f'{self.__class__.__name__} {tuple(self.size)}'

    def _repr_image_format(self, format_str):
        "Always `None`: no image representation is produced for points."
        return None

    @property
    def flow(self)->FlowField:
        "Return the flow after applying (and clearing) any queued affine matrix and coord functions."
        pending_affine = self._affine_mat
        if pending_affine is not None:
            self._flow = _affine_inv_mult(self._flow, pending_affine)
            self._affine_mat = None
            self.transformed = True
        if self.flow_func:
            # Queued functions are applied last-queued first.
            for fn in reversed(self.flow_func):
                self._flow = fn(self._flow)
            self.transformed = True
            self.flow_func = []
        return self._flow

    @flow.setter
    def flow(self,v:FlowField):
        self._flow = v

    def coord(self, func:CoordFunc, *args, **kwargs)->'ImagePoints':
        "Queue `func` with `args` and `kwargs` in `self.flow_func`; warn when it takes no `invert` kwarg."
        if 'invert' not in kwargs:
            warn(f"{func.__name__} isn't implemented for {self.__class__}.")
        else:
            # Whatever value was passed, points always use the inverted form.
            kwargs['invert'] = True
        self.flow_func.append(partial(func, *args, **kwargs))
        return self

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any)->'ImagePoints':
        "Lighting transforms do nothing on points: `self` is returned unchanged."
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'ImagePoints':
        "Return `func(self, *args, **kwargs)` after flagging the result as transformed."
        result = func(self, *args, **kwargs)
        result.transformed = True
        return result

    def refresh(self) -> 'ImagePoints':
        "Nothing to refresh: return `self`."
        return self

    def resize(self, size:Union[int,TensorImageSize]) -> 'ImagePoints':
        "Resize the image to `size`, size can be a single int."
        if isinstance(size, int):
            size = (1, size, size)
        # Only the entries after the first one are stored as the flow size.
        self._flow.size = size[1:]
        return self

    @property
    def data(self)->Tensor:
        "Return the points associated to this object."
        current = self.flow  # reading `flow` applies pending transforms and updates `transformed`
        if self.transformed:
            if self.sample_kwargs.get('remove_out', True):
                current = _remove_points_out(current)
            self.transformed = False
        return current.flow.flip(1)

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True, **kwargs):
        "Show the `ImagePoints` on `ax`."
        if ax is None:
            ax = plt.subplots(figsize=figsize)[1]
        rescaled = scale_flow(FlowField(self.size, self.data), to_unit=False)
        pnt = rescaled.flow.flip(1)
        # Caller-supplied kwargs override the default marker styling.
        scatter_opts = {'s': 10, 'marker': '.', 'c': 'r', **kwargs}
        ax.scatter(pnt[:, 0], pnt[:, 1], **scatter_opts)
        if hide_axis:
            ax.axis('off')
        if title:
            ax.set_title(title)
class ImageBBox(ImagePoints):
    "Support applying transforms to a `flow` of bounding boxes."

    def __init__(self, flow:FlowField, scale:bool=True, y_first:bool=True, labels:Collection=None,
                 classes:dict=None, pad_idx:int=0):
        "Build from `flow`; raw `labels` are wrapped as `Category(l, classes[l])` unless already `Category`."
        super().__init__(flow, scale, y_first)
        self.pad_idx = pad_idx
        if labels is not None and len(labels)>0 and not isinstance(labels[0],Category):
            labels = array([Category(l,classes[l]) for l in labels])
        self.labels = labels

    def clone(self) -> 'ImageBBox':
        "Mimic the behavior of torch.clone for `Image` objects."
        flow = FlowField(self.size, self.flow.flow.clone())
        return self.__class__(flow, scale=False, y_first=False, labels=self.labels, pad_idx=self.pad_idx)

    @classmethod
    def create(cls, h:int, w:int, bboxes:Collection[Collection[int]], labels:Collection=None, classes:dict=None,
               pad_idx:int=0, scale:bool=True)->'ImageBBox':
        "Create an ImageBBox object from `bboxes` (four numbers per box) for an image of size `h` x `w`."
        # `np.object` was only an alias of the builtin `object` and no longer exists in
        # recent NumPy releases; comparing against `object` works on every version.
        if isinstance(bboxes, np.ndarray) and bboxes.dtype == object:
            bboxes = np.array([bb for bb in bboxes])
        bboxes = tensor(bboxes).float()
        # Expand each box (columns 0..3) to its four corners, stored as consecutive 2-d points.
        tr_corners = torch.cat([bboxes[:,0][:,None], bboxes[:,3][:,None]], 1)
        bl_corners = bboxes[:,1:3].flip(1)
        bboxes = torch.cat([bboxes[:,:2], tr_corners, bl_corners, bboxes[:,2:]], 1)
        flow = FlowField((h,w), bboxes.view(-1,2))
        return cls(flow, labels=labels, classes=classes, pad_idx=pad_idx, y_first=True, scale=scale)

    def _compute_boxes(self) -> Tuple[LongTensor, LongTensor]:
        "Rebuild boxes from the corner points, clamped to [-1, 1], dropping boxes with no positive extent."
        bboxes = self.flow.flow.flip(1).view(-1, 4, 2).contiguous().clamp(min=-1, max=1)
        mins, maxes = bboxes.min(dim=1)[0], bboxes.max(dim=1)[0]
        bboxes = torch.cat([mins, maxes], 1)
        # Keep only boxes whose extent is strictly positive along both coordinates.
        mask = (bboxes[:,2]-bboxes[:,0] > 0) * (bboxes[:,3]-bboxes[:,1] > 0)
        if len(mask) == 0:
            # No box at all: return padding values instead.
            return tensor([self.pad_idx] * 4), tensor([self.pad_idx])
        res = bboxes[mask]
        if self.labels is None:
            return res,None
        return res, self.labels[to_np(mask).astype(bool)]

    @property
    def data(self)->Union[FloatTensor, Tuple[FloatTensor,LongTensor]]:
        "The computed boxes, paired with an array of label data when labels are present."
        bboxes,lbls = self._compute_boxes()
        lbls = np.array([o.data for o in lbls]) if lbls is not None else None
        return bboxes if lbls is None else (bboxes, lbls)

    def show(self, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
             color:str='white', **kwargs):
        "Show the `ImageBBox` on `ax`."
        if ax is None:
            _,ax = plt.subplots(figsize=figsize)
        bboxes, lbls = self._compute_boxes()
        h,w = self.flow.size
        # Rescale the boxes in place from [-1, 1] to pixel units.
        # NOTE(review): the result of `.long()` is discarded, so the boxes stay float.
        bboxes.add_(1).mul_(torch.tensor([h/2, w/2, h/2, w/2])).long()
        for i, bbox in enumerate(bboxes):
            text = str(lbls[i]) if lbls is not None else None
            _draw_rect(ax, bb2hw(bbox), text=text, color=color)
def open_image(fn:PathOrStr, div:bool=True, convert_mode:str='RGB', cls:type=Image,
               after_open:Callable=None)->Image:
    "Open the file `fn` with PIL, convert it to `convert_mode` and wrap the float tensor in `cls`."
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
        pil_img = PIL.Image.open(fn).convert(convert_mode)
    if after_open:
        pil_img = after_open(pil_img)
    pixels = pil2tensor(pil_img, np.float32)
    if div:
        # In-place division by 255.
        pixels.div_(255)
    return cls(pixels)
def open_mask(fn:PathOrStr, div=False, convert_mode='L', after_open:Callable=None)->ImageSegment:
    "Open the mask in file `fn` as an `ImageSegment` via `open_image`. If `div`, divides pixel values by 255."
    mask = open_image(fn, cls=ImageSegment, div=div, convert_mode=convert_mode, after_open=after_open)
    return mask
def open_mask_rle(mask_rle:str, shape:Tuple[int, int])->ImageSegment:
    "Return an `ImageSegment` decoded from the run-length encoded string `mask_rle`, with size in `shape`."
    decoded = rle_decode(str(mask_rle), shape).astype(np.uint8)
    mask = FloatTensor(decoded).view(shape[1], shape[0], -1)
    # Move the trailing axis to the front before wrapping.
    return ImageSegment(mask.permute(2,0,1))
def rle_encode(img:NPArrayMask)->str:
    "Return the run-length encoding of `img` as a space-separated string of 1-based starts and lengths."
    # A zero on each side guarantees every run has both a start and an end transition.
    padded = np.concatenate([[0], img.flatten(), [0]])
    changes = np.where(padded[1:] != padded[:-1])[0] + 1
    # Turn each (start, end) pair into (start, length).
    changes[1::2] -= changes[::2]
    return ' '.join(map(str, changes))
def rle_decode(mask_rle:str, shape:Tuple[int,int])->NPArrayMask:
    "Return an array of `shape` holding 1 inside the runs of the run-length encoded string `mask_rle`."
    tokens = mask_rle.split()
    # Tokens alternate: 1-based start, run length, 1-based start, run length, ...
    starts = np.asarray(tokens[0::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)
    starts -= 1
    ends = starts + lengths
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint)
    for begin, stop in zip(starts, ends):
        flat[begin:stop] = 1
    return flat.reshape(shape)
def show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',
               alpha:float=None, **kwargs)->plt.Axes:
    "Draw `img.data` on `ax` (a new `figsize` figure is created when `ax` is None) and return the axes."
    if ax is None:
        ax = plt.subplots(figsize=figsize)[1]
    pixels = image2np(img.data)
    ax.imshow(pixels, cmap=cmap, alpha=alpha, **kwargs)
    if hide_axis:
        ax.axis('off')
    return ax
def scale_flow(flow, to_unit=True):
    "Rescale `flow.flow` in place: to -1/1 when `to_unit`, otherwise back to the image size. Returns `flow`."
    half_size = tensor([flow.size[0]/2,flow.size[1]/2])[None]
    if to_unit:
        rescaled = flow.flow/half_size - 1
    else:
        rescaled = (flow.flow + 1)*half_size
    flow.flow = rescaled
    return flow
def _remove_points_out(flow:FlowField):
    "Drop from `flow.flow` every point with a coord outside [-1, 1]; `flow` is modified and returned."
    first, second = flow.flow[:,0], flow.flow[:,1]
    first_ok = (first >= -1) & (first <= 1)
    second_ok = (second >= -1) & (second <= 1)
    flow.flow = flow.flow[first_ok & second_ok]
    return flow
class Transform():
    "Utility class for adding probability and wrapping support to transform `func`."
    _wrap=None
    order=0

    def __init__(self, func:Callable, order:Optional[int]=None):
        "Wrap `func`, optionally overriding the class-level `order`, and attach it to `Image` as a method."
        if order is not None:
            self.order = order
        self.func = func
        # Transform functions are named with a leading underscore; strip it.
        self.func.__name__ = func.__name__[1:]
        functools.update_wrapper(self, self.func)
        self.func.__annotations__['return'] = Image
        self.params = copy(func.__annotations__)
        self.def_args = _get_default_args(func)
        # The method is registered on `Image` under the stripped name.
        setattr(Image, func.__name__,
                lambda x, *args, **kwargs: self.calc(x, *args, **kwargs))

    def __call__(self, *args:Any, p:float=1., is_random:bool=True, use_on_y:bool=True, **kwargs:Any)->Image:
        "With positional `args`, apply right away; otherwise build a `RandTransform` carrying `kwargs` and `p`."
        if not args:
            return RandTransform(self, kwargs=kwargs, is_random=is_random, use_on_y=use_on_y, p=p)
        return self.calc(*args, **kwargs)

    def calc(self, x:Image, *args:Any, **kwargs:Any)->Image:
        "Apply to image `x`, going through the `x` method named by `_wrap` when one is set."
        if not self._wrap:
            return self.func(x, *args, **kwargs)
        wrapper = getattr(x, self._wrap)
        return wrapper(self.func, *args, **kwargs)

    @property
    def name(self)->str:
        "Name of this transform's class."
        return self.__class__.__name__

    def __repr__(self)->str:
        return f'{self.name} ({self.func.__name__})'
@dataclass
class RandTransform():
    "Wrap `Transform` to add randomized execution."
    tfm:Transform
    kwargs:dict
    p:float=1.0
    resolved:dict = field(default_factory=dict)
    do_run:bool = True
    is_random:bool = True
    use_on_y:bool = True

    def __post_init__(self):
        functools.update_wrapper(self, self.tfm)

    def resolve(self)->None:
        "Fix the arguments (drawing random ones) used by the next calls and decide whether to run."
        if not self.is_random:
            # Deterministic mode: defaults overridden by the given kwargs, nothing drawn.
            self.resolved = {**self.tfm.def_args, **self.kwargs}
            return

        bound = self.resolved = {}
        annotations = self.tfm.params
        # 1. Passed kwargs: annotated ones go through their annotation, others are kept as given.
        for name, value in self.kwargs.items():
            if name in annotations:
                bound[name] = annotations[name](*listify(value))
            else:
                bound[name] = value
        # 2. Defaults fill anything still missing.
        for name, default in self.tfm.def_args.items():
            bound.setdefault(name, default)
        # 3. Remaining annotated params (except the return annotation) are called with no argument.
        for name, annotation in annotations.items():
            if name != 'return' and name not in bound:
                bound[name] = annotation()
        self.do_run = rand_bool(self.p)

    @property
    def order(self)->int:
        "Order of the wrapped transform."
        return self.tfm.order

    def __call__(self, x:Image, *args, **kwargs)->Image:
        "Apply the wrapped tfm to `x` with the resolved args (overridden by `kwargs`), or return `x` if `do_run` is off."
        if not self.do_run:
            return x
        return self.tfm(x, *args, **{**self.resolved, **kwargs})
def _resolve_tfms(tfms:TfmList):
    "Call `resolve` on each transform of `tfms`."
    all_tfms = listify(tfms)
    for tfm in all_tfms:
        tfm.resolve()
def _grid_sample(x:TensorImage, coords:FlowField, mode:str='bilinear', padding_mode:str='reflection', remove_out:bool=True)->TensorImage:
    "Resample pixels in `coords` from `x` by `mode`, with `padding_mode` in ('reflection','border','zeros')."
    # NOTE(review): `remove_out` is accepted but not used in this function.
    grid = coords.flow.permute(0, 3, 1, 2).contiguous().permute(0, 2, 3, 1) # optimize layout for grid_sample
    if mode=='bilinear': # hack to get smoother downwards resampling
        span = (grid.max() - grid.min()).item()
        # zoom implied by the extent of the grid (>1 means zooming in)
        zoom = 1/span*2
        # half of the smaller ratio between the input size and the grid size
        shrink = min(x.shape[1]/grid.shape[1], x.shape[2]/grid.shape[2])/2
        if shrink>1 and shrink>zoom:
            # Pre-shrink with area interpolation before the final sampling.
            x = F.interpolate(x[None], scale_factor=1/shrink, mode='area')[0]
    sampled = F.grid_sample(x[None], grid, mode=mode, padding_mode=padding_mode)
    return sampled[0]
def _affine_grid(size:TensorImageSize)->FlowField:
    "Build a `FlowField` whose grid holds coords evenly spaced in [-1, 1] for an image of `size`."
    size = (1,) + size
    N, C, H, W = size
    grid = FloatTensor(N, H, W, 2)
    # Last-axis index 0: values varying along the width.
    along_w = torch.linspace(-1, 1, W) if W > 1 else tensor([-1])
    grid[:, :, :, 0] = torch.ger(torch.ones(H), along_w).expand_as(grid[:, :, :, 0])
    # Last-axis index 1: values varying along the height.
    along_h = torch.linspace(-1, 1, H) if H > 1 else tensor([-1])
    grid[:, :, :, 1] = torch.ger(along_h, torch.ones(W)).expand_as(grid[:, :, :, 1])
    return FlowField(size[2:], grid)
def _affine_mult(c:FlowField,m:AffineMatrix)->FlowField:
    "Apply the affine matrix `m` to the coords of `c` in place, adjusting `m` for a rectangular `c`."
    if m is None:
        return c
    original_shape = c.flow.size()
    h,w = c.size
    # The off-diagonal terms of `m` are rescaled in place by the aspect ratio.
    m[0,1] *= h/w
    m[1,0] *= w/h
    c.flow = c.flow.view(-1,2)
    linear, offset = m[:2,:2], m[:2,2]
    moved = torch.addmm(offset, c.flow, linear.t())
    c.flow = moved.view(original_shape)
    return c
def _affine_inv_mult(c, m):
    "Apply to the coords of `c`, in place, the inverse of the affine transform described in `m`."
    original_shape = c.flow.size()
    h,w = c.size
    # The off-diagonal terms of `m` are rescaled in place by the aspect ratio.
    m[0,1] *= h/w
    m[1,0] *= w/h
    c.flow = c.flow.view(-1,2)
    linear, offset = m[:2,:2], m[:2,2]
    inv_linear_t = torch.inverse(linear.t())
    # Undo the translation first, then the linear part.
    c.flow = torch.mm(c.flow - offset, inv_linear_t).view(original_shape)
    return c
class TfmAffine(Transform):
    "Decorator for affine tfm funcs."
    order = 5
    _wrap = 'affine'


class TfmPixel(Transform):
    "Decorator for pixel tfm funcs."
    order = 10
    _wrap = 'pixel'


class TfmCoord(Transform):
    "Decorator for coord tfm funcs."
    order = 4
    _wrap = 'coord'


class TfmCrop(TfmPixel):
    "Decorator for crop tfm funcs."
    # Same wrapping as `TfmPixel`, with a larger `order`.
    order = 99


class TfmLighting(Transform):
    "Decorator for lighting tfm funcs."
    order = 8
    _wrap = 'lighting'
def _round_multiple(x:int, mult:int=None)->int:
    "Round `x` to the nearest multiple of `mult`; `x` is returned untouched when `mult` is None."
    if mult is None:
        return x
    return int(x/mult + 0.5) * mult
def _get_crop_target(target_px:Union[int,TensorImageSize], mult:int=None)->Tuple[int,int]:
    "Crop shape for `target_px`, each side rounded to the nearest multiple of `mult`."
    rows, cols = tis2hw(target_px)
    rounded_rows = _round_multiple(rows, mult)
    rounded_cols = _round_multiple(cols, mult)
    return rounded_rows, rounded_cols
def _get_resize_target(img, crop_target, do_crop=False)->TensorImageSize:
    "Size `img` should be resized to for `crop_target`, keeping its aspect ratio; None if no target."
    if crop_target is None:
        return None
    ch,r,c = img.shape
    target_r,target_c = crop_target
    # With `do_crop` the smaller ratio is used, otherwise the larger one.
    pick = min if do_crop else max
    ratio = pick(r/target_r, c/target_c)
    # int() because round on numpy numbers doesn't always return an int.
    new_r = int(round(r/ratio))
    new_c = int(round(c/ratio))
    return ch, new_r, new_c
def plot_flat(r, c, figsize):
    "Shortcut for `enumerate(subplots.flatten())`"
    # `squeeze=False` makes `subplots` always return a 2-d array of axes, so a
    # 1x1 grid (which would otherwise be a bare Axes without `.flatten()`) works too.
    axes = plt.subplots(r, c, figsize=figsize, squeeze=False)[1]
    return enumerate(axes.flatten())
def plot_multi(func:Callable[[int,int,plt.Axes],None], r:int=1, c:int=1, figsize:Tuple=(12,6)):
    "Call `func(i, j, ax)` for every combination of `r,c` on a subplot"
    # `squeeze=False` keeps `axes` 2-d even when `r` or `c` is 1 (the defaults),
    # otherwise the `axes[i,j]` indexing below fails on a 1-d array or a bare Axes.
    axes = plt.subplots(r, c, figsize=figsize, squeeze=False)[1]
    for i in range(r):
        for j in range(c):
            func(i, j, axes[i,j])
def show_multi(func:Callable[[int,int],Image], r:int=1, c:int=1, figsize:Tuple=(9,9)):
    "Call `func(i,j).show(ax)` for every combination of `r,c`"
    def _show_cell(i, j, ax):
        return func(i,j).show(ax)
    plot_multi(_show_cell, r, c, figsize=figsize)
def show_all(imgs:Collection[Image], r:int=1, c:Optional[int]=None, figsize=(12,6)):
    "Show all `imgs` on a grid of `r` rows and `c` columns (`len(imgs)//r` columns when `c` is None)."
    imgs = listify(imgs)
    cols = len(imgs)//r if c is None else c
    for idx, ax in plot_flat(r, cols, figsize):
        imgs[idx].show(ax)
================================================
FILE: fastai/vision/interpret.py
================================================
from ..torch_core import *
from ..basic_data import *
from ..basic_train import *
from .image import *
from ..train import Interpretation
from textwrap import wrap
__all__ = ['SegmentationInterpretation', 'ObjectDetectionInterpretation']
class SegmentationInterpretation(Interpretation):
    "Interpretation methods for segmentation models."

    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor,
                 ds_type:DatasetType=DatasetType.Valid):
        "Store the predicted class (argmax of `preds` over dim 1) and class <-> index lookups."
        super(SegmentationInterpretation, self).__init__(learn,preds,y_true,losses,ds_type)
        self.pred_class = self.preds.argmax(dim=1)
        self.c2i = {c:i for i,c in enumerate(self.data.classes)}
        self.i2c = {i:c for c,i in self.c2i.items()}

    def top_losses(self, sizes:Tuple, k:int=None, largest=True):
        "Reduce the flattened losses to one mean value per image of `sizes`, and return the top `k` (all if None)."
        losses = self.losses.view(-1, np.prod(sizes)).mean(-1)
        return losses.topk(ifnone(k, len(losses)), largest=largest)

    def _interp_show(self, ims:ImageSegment, classes:Collection=None, sz:int=20, cmap='tab20',
                     title_suffix:str=None):
        "Show `ims` next to a legend of its labels; if `classes` is given, only those classes are drawn."
        fig,axes=plt.subplots(1,2,figsize=(sz,sz))
        np_im = to_np(ims.data).copy()
        # tab20 - qualitative colormaps support a max of 20 distinct colors
        # if len(classes) > 20, close idxs map to the same color
        # image
        if classes is not None:
            class_idxs = [self.c2i[c] for c in classes]
            mask = np.max(np.stack([np_im==i for i in class_idxs]),axis=0)
            # `np.float` was only an alias of the builtin `float` and no longer exists
            # in recent NumPy releases; the builtin gives the same float64 array.
            np_im = (np_im*mask).astype(float)
            # Pixels outside the requested classes become NaN.
            np_im[np.where(mask==0)] = np.nan
        axes[0].imshow(np_im[0], cmap=cmap)
        # labels: the unique non-NaN values, laid out on an n x n grid padded with NaN
        np_im_labels = list(np.unique(np_im[~np.isnan(np_im)]))
        c = len(np_im_labels); n = math.ceil(np.sqrt(c))
        label_im = np.array(np_im_labels + [np.nan]*(n**2-c)).reshape(n,n)
        axes[1].imshow(label_im, cmap=cmap)
        for i,l in enumerate([self.i2c[l] for l in np_im_labels]):
            div,mod=divmod(i,n)
            # Names longer than 10 characters are wrapped onto several lines.
            l = "\n".join(wrap(l,10)) if len(l) > 10 else l
            axes[1].text(mod, div, f"{l}", ha='center', color='white', fontdict={'size':sz})
        if title_suffix:
            axes[0].set_title(f"{title_suffix}_imsegment")
            axes[1].set_title(f"{title_suffix}_labels")

    def show_xyz(self, i, classes:list=None, sz=10):
        "Show item `i` of `self.ds` (image, true mask and predicted mask), optionally only plotting `classes`."
        x,y = self.ds[i]
        self.ds.show_xys([x],[y], figsize=(sz/2,sz/2))
        self._interp_show(ImageSegment(self.y_true[i]), classes, sz=sz, title_suffix='true')
        self._interp_show(ImageSegment(self.pred_class[i][None,:]), classes, sz=sz, title_suffix='pred')

    def _generate_confusion(self):
        "Average and Per Image Confusion: intersection of pixels given a true label, true label sums to 1"
        single_img_confusion = []
        mean_confusion = []
        n = self.pred_class.shape[0]
        for c_j in range(self.data.c):
            true_binary = self.y_true.squeeze(1) == c_j
            total_true = true_binary.view(n,-1).sum(dim=1).float()
            for c_i in range(self.data.c):
                pred_binary = self.pred_class == c_i
                total_intersect = (true_binary*pred_binary).view(n,-1).sum(dim=1).float()
                # NaN for images with no pixel of class `c_j`; those are left out of the mean.
                p_given_t = (total_intersect / (total_true))
                p_given_t_mean = p_given_t[~torch.isnan(p_given_t)].mean()
                single_img_confusion.append(p_given_t)
                mean_confusion.append(p_given_t_mean)
        self.single_img_cm = to_np(torch.stack(single_img_confusion).permute(1,0).view(-1, self.data.c, self.data.c))
        self.mean_cm = to_np(torch.tensor(mean_confusion).view(self.data.c, self.data.c))
        return self.mean_cm, self.single_img_cm

    def _plot_intersect_cm(self, cm, title="Intersection with Predict given True"):
        "Plot the confusion matrix `cm` and return a DataFrame of its diagonal scores, sorted descending."
        from IPython.display import display, HTML
        fig,ax=plt.subplots(1,1,figsize=(10,10))
        im=ax.imshow(cm, cmap="Blues")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title(f"{title}")
        ax.set_xticks(range(self.data.c))
        ax.set_yticks(range(self.data.c))
        ax.set_xticklabels(self.data.classes, rotation='vertical')
        ax.set_yticklabels(self.data.classes)
        fig.colorbar(im)
        df = (pd.DataFrame([self.data.classes, cm.diagonal()], index=['label', 'score'])
              .T.sort_values('score', ascending=False))
        # NOTE(review): newer pandas releases may expect `None` instead of -1 here — confirm
        # against the pandas versions this project supports before changing it.
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
        return df
class ObjectDetectionInterpretation(Interpretation):
    "Placeholder for interpretation methods for object detection models (not implemented)."

    def __init__(self, learn:Learner, preds:Tensor, y_true:Tensor, losses:Tensor, ds_type:DatasetType=DatasetType.Valid):
        "Always raises `NotImplementedError`."
        # The original chained to `Interpretation.__init__` after the raise; that call
        # could never run, so it has been dropped.
        raise NotImplementedError
================================================
FILE: fastai/vision/learner.py
================================================
"`Learner` support for computer vision"
from ..torch_core import *
from ..basic_train import *
from ..basic_data import *
from .image import *
from . import models
from ..callback import *
from ..layers import *
from ..callbacks.hooks import *
from ..train import ClassificationInterpretation
__all__ = ['cnn_learner', 'create_cnn', 'create_cnn_model', 'create_body', 'create_head', 'unet_learner']
def _default_split(m:nn.Module):
    "By default split models between first and second layer."
    return (m[1],)


def _resnet_split(m:nn.Module):
    "Split a resnet style model."
    return (m[0][6], m[1])


def _squeezenet_split(m:nn.Module):
    "Split squeezenet model on maxpool layers."
    body = m[0][0]
    return (body[5], body[8], m[1])


def _densenet_split(m:nn.Module):
    "Split a densenet style model at index 7 of `m[0][0]` and at `m[1]`."
    return (m[0][0][7], m[1])


def _vgg_split(m:nn.Module):
    "Split a vgg style model at index 22 of `m[0][0]` and at `m[1]`."
    return (m[0][0][22], m[1])


def _alexnet_split(m:nn.Module):
    "Split an alexnet style model at index 6 of `m[0][0]` and at `m[1]`."
    return (m[0][0][6], m[1])


# Per-family metadata: `cut` is where the pretrained model is cut, `split` how it is grouped.
_default_meta = dict(cut=None, split=_default_split)
_resnet_meta = dict(cut=-2, split=_resnet_split)
_squeezenet_meta = dict(cut=-1, split=_squeezenet_split)
_densenet_meta = dict(cut=-1, split=_densenet_split)
_vgg_meta = dict(cut=-1, split=_vgg_split)
_alexnet_meta = dict(cut=-1, split=_alexnet_split)
# Registry mapping each supported architecture constructor to its own copy of the
# metadata of its family.
model_meta = {
    arch: dict(family_meta)
    for family, family_meta in (
        ((models.resnet18, models.resnet34, models.resnet50, models.resnet101, models.resnet152),
         _resnet_meta),
        ((models.squeezenet1_0, models.squeezenet1_1), _squeezenet_meta),
        ((models.densenet121, models.densenet169, models.densenet201, models.densenet161),
         _densenet_meta),
        ((models.vgg16_bn, models.vgg19_bn), _vgg_meta),
        ((models.alexnet,), _alexnet_meta),
    )
    for arch in family
}
def cnn_config(arch):
    "Return the `model_meta` entry for `arch`, or `_default_meta` when `arch` is not registered."
    try:
        return model_meta[arch]
    except KeyError:
        return _default_meta
def has_pool_type(m):
    "True if `m` satisfies `is_pool_type` or, recursively, any of its children does."
    if is_pool_type(m):
        return True
    return any(has_pool_type(child) for child in m.children())
def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):
    """Cut off the body of a typically pretrained `model` at `cut` (int) or cut the model as specified by `cut(model)` (function).

    When `cut` is None, the value registered for `arch` in `cnn_config` is used; if that is
    None too, the model is cut at the last child that contains a pooling layer.
    Raises `TypeError` when `cut` is neither an int nor a callable.
    """
    model = arch(pretrained=pretrained)
    cut = ifnone(cut, cnn_config(arch)['cut'])
    if cut is None:
        ll = list(enumerate(model.children()))
        # Search from the end for the last child holding a pooling layer.
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int):
        return nn.Sequential(*list(model.children())[:cut])
    if callable(cut):
        return cut(model)
    # The original raised the undefined name `NamedError`, which surfaced as a NameError.
    raise TypeError("cut must be either integer or a function")
def create_head(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
                concat_pool:bool=True, bn_final:bool=False):
    """Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `nc` classes.

    `lin_ftrs` gives the intermediate layer sizes (a single layer of 512 when None) and may be
    any collection of ints. A single `ps` value is used for the last layer and halved for the
    ones before it. `concat_pool` selects `AdaptiveConcatPool2d` over average pooling, and
    `bn_final` appends a final `BatchNorm1d`.
    """
    # `list(...)` lets tuples and other collections through; `[nf] + lin_ftrs` only worked for lists.
    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    ps = listify(ps)
    if len(ps) == 1:
        ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    # ReLU after every layer except the last one.
    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
    layers = [pool, Flatten()]
    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
        layers += bn_drop_lin(ni, no, True, p, actn)
    if bn_final:
        layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    return nn.Sequential(*layers)
def create_cnn_model(base_arch:Callable, nc:int, cut:Union[int,Callable]=None, pretrained:bool=True,
                     lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5, custom_head:Optional[nn.Module]=None,
                     bn_final:bool=False, concat_pool:bool=True):
    "Create a convnet made of the body of `base_arch` followed by `custom_head` or a freshly built head."
    body = create_body(base_arch, pretrained, cut)
    if custom_head is not None:
        head = custom_head
    else:
        # Concat pooling doubles the number of features the head receives.
        pool_mult = 2 if concat_pool else 1
        nf = num_features_model(nn.Sequential(*body.children())) * pool_mult
        head = create_head(nf, nc, lin_ftrs, ps=ps, concat_pool=concat_pool, bn_final=bn_final)
    return nn.Sequential(body, head)
def cnn_learner(data:DataBunch, base_arch:Callable, cut:Union[int,Callable]=None, pretrained:bool=True,
                lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5, custom_head:Optional[nn.Module]=None,
                split_on:Optional[SplitFuncOrIdxList]=None, bn_final:bool=False, init=nn.init.kaiming_normal_,
                concat_pool:bool=True, **kwargs:Any)->Learner:
    "Build a `Learner` on `data` around a convnet created from `base_arch`; extra `kwargs` go to `Learner`."
    meta = cnn_config(base_arch)
    model = create_cnn_model(base_arch, data.c, cut, pretrained, lin_ftrs, ps=ps, custom_head=custom_head,
                             bn_final=bn_final, concat_pool=concat_pool)
    learn = Learner(data, model, **kwargs)
    splitter = split_on or meta['split']
    learn.split(splitter)
    if pretrained:
        learn.freeze()
    if init:
        # Only the head (`model[1]`) is initialized.
        apply_init(model[1], init)
    return learn
def create_cnn(data, base_arch, **kwargs):
    "Deprecated alias of `cnn_learner`: warn, then forward every argument to it."
    warn("`create_cnn` is deprecated and is now named `cnn_learner`.")
    learn = cnn_learner(data, base_arch, **kwargs)
    return learn
def unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, blur:bool=False,
                 self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
                 bottle:bool=False, cut:Union[int,Callable]=None, **learn_kwargs:Any)->Learner:
    """Build Unet learner from `data` and `arch`.

    The body of `arch` is wrapped in a `DynamicUnet` sized from the first training item (or,
    failing that, from the first training batch). `learn_kwargs` are passed to `Learner`.
    """
    # NOTE(review): the default of `norm_type` is the `NormType` class itself, not a member — confirm intent.
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    try:
        size = data.train_ds[0][0].size
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        # Fall back on the last two dims of the inputs of the first batch.
        size = next(iter(data.train_dl))[0].shape[-2:]
    model = to_device(models.unet.DynamicUnet(body, n_classes=data.c, img_size=size, blur=blur, blur_final=blur_final,
                      self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
                      bottle=bottle), data.device)
    learn = Learner(data, model, **learn_kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    # Only `model[2]` is re-initialized.
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
@classmethod
def _cl_int_from_learner(cls, learn:Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, tta=False):
    "Create an instance of `cls` from the predictions of `learn` on `ds_type`, using `learn.TTA` when `tta`."
    if tta:
        preds = learn.TTA(ds_type=ds_type, with_loss=True)
    else:
        preds = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)
    return cls(learn, *preds, ds_type=ds_type)
def _test_cnn(m):
    "True if `m` is a two-part `nn.Sequential` whose second part starts with an adaptive pooling layer."
    if not isinstance(m, nn.Sequential):
        return False
    if len(m) != 2:
        return False
    return isinstance(m[1][0], (AdaptiveConcatPool2d, nn.AdaptiveAvgPool2d))
def _cl_int_gradcam(self, idx, heatmap_thresh:int=16, image:bool=True):
    """Compute a Grad-CAM style heatmap for item `idx` of the interpreted dataset.

    Returns the heatmap, or None when the activation map of `m[0]` has fewer than
    `heatmap_thresh` spatial positions. When `image` is True the input is also
    plotted with the heatmap overlaid.
    """
    m = self.learn.model.eval()
    # Use the dataset this interpretation was built on, as `plot_top_losses` does;
    # the original always read from the validation set, whatever `self.ds_type` was.
    im,cl = self.learn.data.dl(self.ds_type).dataset[idx]
    cl = int(cl)
    xb,_ = self.data.one_item(im, detach=False, denorm=False) #put into a minibatch of batch size = 1
    # Hook both the activations and the gradients of the body `m[0]`.
    with hook_output(m[0]) as hook_a, hook_output(m[0], grad=True) as hook_g:
        preds = m(xb)
        preds[0,int(cl)].backward()
    acts = hook_a.stored[0].cpu() #activation maps
    if (acts.shape[-1]*acts.shape[-2]) < heatmap_thresh:
        # Activation map too small for a heatmap.
        return None
    grad = hook_g.stored[0][0].cpu()
    grad_chan = grad.mean(1).mean(1)
    # Weight each activation channel by its mean gradient, sum over channels, keep the positive part.
    mult = F.relu(((acts*grad_chan[...,None,None])).sum(0))
    if image:
        xb_im = Image(xb[0])
        _,ax = plt.subplots()
        sz = list(xb_im.shape[-2:])
        xb_im.show(ax,title=f"pred. class: {self.pred_class[idx]}, actual class: {self.learn.data.classes[cl]}")
        ax.imshow(mult, alpha=0.4, extent=(0,*sz[::-1],0),
                  interpolation='bilinear', cmap='magma')
    return mult
ClassificationInterpretation.GradCAM =_cl_int_gradcam
def _cl_int_plot_top_losses(self, k, largest=True, figsize=(12,12), heatmap:bool=False, heatmap_thresh:int=16,
                            return_fig:bool=None)->Optional[plt.Figure]:
    """Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class.

    With `heatmap` (which requires a model shaped like the ones `cnn_learner` produces), the
    `GradCAM` heatmap is overlaid when available. The figure is returned only when
    `return_fig` (or `defaults.return_fig` if None) is truthy.
    """
    assert not heatmap or _test_cnn(self.learn.model), "`heatmap=True` requires a model like `cnn_learner` produces."
    if heatmap is None:
        heatmap = _test_cnn(self.learn.model)
    tl_val,tl_idx = self.top_losses(k, largest)
    classes = self.data.classes
    cols = math.ceil(math.sqrt(k))
    rows = math.ceil(k/cols)
    # `squeeze=False` keeps `axes` an array even for a 1x1 grid (k=1); without it
    # `subplots` returns a bare Axes and `axes.flat` fails.
    fig,axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    fig.suptitle('prediction/actual/loss/probability', weight='bold', size=14)
    for i,idx in enumerate(tl_idx):
        im,cl = self.data.dl(self.ds_type).dataset[idx]
        cl = int(cl)
        im.show(ax=axes.flat[i], title=
            f'{classes[self.pred_class[idx]]}/{classes[cl]} / {self.losses[idx]:.2f} / {self.preds[idx][cl]:.2f}')
        if heatmap:
            mult = self.GradCAM(idx,heatmap_thresh,image=False)
            if mult is not None:
                sz = list(im.shape[-2:])
                axes.flat[i].imshow(mult, alpha=0.6, extent=(0,*sz[::-1],0), interpolation='bilinear', cmap='magma')
    if ifnone(return_fig, defaults.return_fig):
        return fig
def _cl_int_plot_multi_top_losses(self, samples:int=3, figsize:Tuple[int,int]=(8,8), save_misclassified:bool=False):
    """Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class in a multilabeled dataset.

    At most 20 `samples` are accepted. The misclassified items are shown by decreasing loss;
    with `save_misclassified`, the list of all misclassified images is returned.
    """
    if samples >20:
        print("Max 20 samples")
        return
    losses, idxs = self.top_losses(self.data.c)
    l_dim = len(losses.size())
    if l_dim == 1:
        losses, idxs = self.top_losses()
    infolist, ordlosses_idxs, mismatches_idxs, mismatches, losses_mismatches, mismatchescontainer = [],[],[],[],[],[]
    truthlabels = np.asarray(self.y_true, dtype=int)
    classes_ids = [k for k in enumerate(self.data.classes)]
    predclass = np.asarray(self.pred_class)
    for i,pred in enumerate(predclass):
        where_truth = np.nonzero((truthlabels[i]>0))[0]
        # A sample is a mismatch when the prediction is none of its true labels.
        mismatch = np.all(pred!=where_truth)
        if mismatch:
            mismatches_idxs.append(i)
            if l_dim > 1 : losses_mismatches.append((losses[i][pred], i))
            else: losses_mismatches.append((losses[i], i))
        if l_dim > 1: infotup = (i, pred, where_truth, losses[i][pred], np.round(self.preds[i], decimals=3)[pred], mismatch)
        else: infotup = (i, pred, where_truth, losses[i], np.round(self.preds[i], decimals=3)[pred], mismatch)
        infolist.append(infotup)
    ds = self.data.dl(self.ds_type).dataset
    mismatches = ds[mismatches_idxs]
    ordlosses = sorted(losses_mismatches, key = lambda x: x[0], reverse=True)
    for w in ordlosses: ordlosses_idxs.append(w[1])
    mismatches_ordered_byloss = ds[ordlosses_idxs]
    print(f'{str(len(mismatches))} misclassified samples over {str(len(self.data.valid_ds))} samples in the validation set.')
    samples = min(samples, len(mismatches))
    for ima in range(len(mismatches_ordered_byloss)):
        mismatchescontainer.append(mismatches_ordered_byloss[ima][0])
    for sampleN in range(samples):
        # The list is built as `infolist`; the original read the undefined name
        # `infoList` here, which raised a NameError.
        info = infolist[ordlosses_idxs[sampleN]]
        actualclasses = ''
        for clas in info[2]:
            actualclasses = f'{actualclasses} -- {str(classes_ids[clas][1])}'
        imag = mismatches_ordered_byloss[sampleN][0]
        imag = show_image(imag, figsize=figsize)
        imag.set_title(f"""Predicted: {classes_ids[info[1]][1]} \nActual: {actualclasses}\nLoss: {info[3]}\nProbability: {info[4]}""",
                       loc='left')
        plt.show()
    if save_misclassified:
        return mismatchescontainer
# Attach the vision-specific helpers defined above to `ClassificationInterpretation`.
ClassificationInterpretation.from_learner = _cl_int_from_learner
ClassificationInterpretation.plot_top_losses = _cl_int_plot_top_losses
ClassificationInterpretation.plot_multi_top_losses = _cl_int_plot_multi_top_losses
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    "Build a `ClassificationInterpretation` for `learn` on `ds_type`, optionally with `tta`."
    interp = ClassificationInterpretation.from_learner(learn, ds_type=ds_type, tta=tta)
    return interp
# Exposed on `Learner` as the `interpret` method.
Learner.interpret = _learner_interpret
================================================
FILE: fastai/vision/models/__init__.py
================================================
from .xresnet import *
from torchvision.models import ResNet,resnet18,resnet34,resnet50,resnet101,resnet152
from torchvision.models import SqueezeNet,squeezenet1_0,squeezenet1_1
from torchvision.models import densenet121,densenet169,densenet201,densenet161
from torchvision.models import vgg16_bn,vgg19_bn,alexnet
from .darknet import *
from .unet import *
from .wrn import *
from .xception import *
================================================
FILE: fastai/vision/models/cadene_models.py
================================================
#These models are downloaded via the repo https://github.com/Cadene/pretrained-models.pytorch
#See licence here: https://github.com/Cadene/pretrained-models.pytorch/blob/master/LICENSE.txt
from torch import nn
from ..learner import model_meta
from ...core import *
# `pretrainedmodels` is an optional dependency: importing this module fails with an
# explicit message when `try_import` returns a falsy value for it.
pretrainedmodels = try_import('pretrainedmodels')
if not pretrainedmodels:
    raise Exception('Error: `pretrainedmodels` is needed. `pip install pretrainedmodels`')
# Architecture constructors exported by this module.
__all__ = ['inceptionv4', 'inceptionresnetv2', 'nasnetamobile', 'dpn92', 'xception_cadene', 'se_resnet50',
           'se_resnet101', 'se_resnext50_32x4d', 'senet154', 'pnasnet5large', 'se_resnext101_32x4d']
def get_model(model_name:str, pretrained:bool, seq:bool=False, pname:str='imagenet', **kwargs):
    "Build `model_name` from `pretrainedmodels` (weights `pname` if `pretrained`), flattened to `nn.Sequential` if `seq`."
    weights = pname if pretrained else None
    builder = getattr(pretrainedmodels, model_name)
    model = builder(pretrained=weights, **kwargs)
    if seq:
        return nn.Sequential(*model.children())
    return model
def inceptionv4(pretrained:bool=False):
    "Inception v4 as one `nn.Sequential`: the children of its first child, then its remaining children."
    children = list(get_model('inceptionv4', pretrained).children())
    first, rest = children[0], children[1:]
    return nn.Sequential(*first, *rest)
model_meta[inceptionv4] = dict(cut=-2, split=lambda m: (m[0][11], m[1]))
def nasnetamobile(pretrained:bool=False):
    "NASNet-A-Mobile with its `logits` replaced by `noop`, wrapped in an `nn.Sequential`."
    net = get_model('nasnetamobile', pretrained, num_classes=1000)
    net.logits = noop
    return nn.Sequential(net)
model_meta[nasnetamobile] = dict(cut=noop, split=lambda m: (list(m[0][0].children())[8], m[1]))
def pnasnet5large(pretrained:bool=False):
    "PNASNet-5-Large with its `logits` replaced by `noop`, wrapped in an `nn.Sequential`."
    net = get_model('pnasnet5large', pretrained, num_classes=1000)
    net.logits = noop
    return nn.Sequential(net)
model_meta[pnasnet5large] = dict(cut=noop, split=lambda m: (list(m[0][0].children())[8], m[1]))
def inceptionresnetv2(pretrained:bool=False):
    "Inception-ResNet v2, returned as a flat `nn.Sequential` of the library model's children."
    return get_model('inceptionresnetv2', pretrained, seq=True)
def dpn92(pretrained:bool=False):
    "DPN-92 as a flat `nn.Sequential`; pretrained weights use the 'imagenet+5k' set."
    return get_model('dpn92', pretrained, seq=True, pname='imagenet+5k')
def xception_cadene(pretrained=False):
    "The library's `xception` model, returned as a flat `nn.Sequential` of its children."
    return get_model('xception', pretrained, seq=True)
def se_resnet50(pretrained:bool=False):
    "SE-ResNet-50 exactly as built by `pretrainedmodels`."
    return get_model('se_resnet50', pretrained)
def se_resnet101(pretrained:bool=False):
    "SE-ResNet-101 exactly as built by `pretrainedmodels`."
    return get_model('se_resnet101', pretrained)
def se_resnext50_32x4d(pretrained:bool=False):
    "SE-ResNeXt-50 (32x4d) exactly as built by `pretrainedmodels`."
    return get_model('se_resnext50_32x4d', pretrained)
def se_resnext101_32x4d(pretrained:bool=False):
    "SE-ResNeXt-101 (32x4d) exactly as built by `pretrainedmodels`."
    return get_model('se_resnext101_32x4d', pretrained)
def senet154(pretrained:bool=False):
    "SENet-154 exactly as built by `pretrainedmodels`."
    return get_model('senet154', pretrained)
# Metadata registered in `model_meta` (imported from `..learner`), keyed by the factory function itself.
# NOTE(review): 'cut' and 'split' are presumably used by the learner to truncate the network and to
# split it into layer groups — confirm against `..learner`.
model_meta[inceptionresnetv2] = {'cut': -2, 'split': lambda m: (m[0][9], m[1])}
model_meta[dpn92] = {'cut': -1, 'split': lambda m: (m[0][0][16], m[1])}
model_meta[xception_cadene] = {'cut': -1, 'split': lambda m: (m[0][11], m[1])}
model_meta[senet154] = {'cut': -3, 'split': lambda m: (m[0][3], m[1])}
# The four SE-ResNet / SE-ResNeXt factories below all share this one dict object.
_se_resnet_meta = {'cut': -2, 'split': lambda m: (m[0][3], m[1])}
model_meta[se_resnet50] = _se_resnet_meta
model_meta[se_resnet101] = _se_resnet_meta
model_meta[se_resnext50_32x4d] = _se_resnet_meta
model_meta[se_resnext101_32x4d] = _se_resnet_meta
# TODO: add "resnext101_32x4d" "resnext101_64x4d" after serialization issue is fixed:
# https://github.com/Cadene/pretrained-models.pytorch/pull/128
================================================
FILE: fastai/vision/models/darknet.py
================================================
from ...torch_core import *
from ...layers import *
__all__ = ['Darknet', 'ResLayer']
def conv_bn_lrelu(ni:int, nf:int, ks:int=3, stride:int=1)->nn.Sequential:
    "Create a sequence Conv2d->BatchNorm2d->LeakyReLU from `ni` to `nf` channels."
    # Bias-free conv with `ks//2` padding.
    conv = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=False)
    norm = nn.BatchNorm2d(nf)
    act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    return nn.Sequential(conv, norm, act)
class ResLayer(Module):
    "Residual layer on `ni` channels: 1x1 conv down to `ni//2`, 3x3 conv back to `ni`, added to the input."
    def __init__(self, ni:int):
        half = ni//2
        self.conv1 = conv_bn_lrelu(ni, half, ks=1)
        self.conv2 = conv_bn_lrelu(half, ni, ks=3)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return x + out
class Darknet(Module):
    "Darknet-style classifier, see https://github.com/pjreddie/darknet"
    def make_group_layer(self, ch_in:int, num_blocks:int, stride:int=1):
        "Return one conv layer doubling `ch_in` (with `stride`) followed by `num_blocks` `ResLayer`s."
        ch_out = ch_in*2
        group = [conv_bn_lrelu(ch_in, ch_out, stride=stride)]
        for _ in range(num_blocks): group.append(ResLayer(ch_out))
        return group

    def __init__(self, num_blocks:Collection[int], num_classes:int, nf=32):
        "Create a darknet with `nf` initial filters and one group per entry of `num_blocks`."
        layers = [conv_bn_lrelu(3, nf, ks=3, stride=1)]
        for i,nb in enumerate(num_blocks):
            # Every group uses stride 2 except the second one (i == 1), which uses stride 1.
            stride = 1 if i == 1 else 2
            layers.extend(self.make_group_layer(nf, nb, stride=stride))
            nf = nf*2
        head = [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(nf, num_classes)]
        self.layers = nn.Sequential(*layers, *head)

    def forward(self, x):
        return self.layers(x)
================================================
FILE: fastai/vision/models/presnet.py
================================================
from pdb import set_trace
import torch.nn.functional as F
import torch.nn as nn
import torch
import math
import torch.utils.model_zoo as model_zoo
__all__ = ['PResNet', 'presnet18', 'presnet34', 'presnet50', 'presnet101', 'presnet152']
act_fn = nn.ReLU
def init_cnn(m):
    "Recursively initialise `m`: zero every bias, Kaiming-normal conv weights, N(0, 0.01) linear weights."
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(m.bias, 0)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
    for child in m.children():
        init_cnn(child)
def conv(ni, nf, ks=3, stride=1, bias=False):
    "`nn.Conv2d` from `ni` to `nf` channels with kernel `ks`, `stride` and `ks//2` padding."
    opts = dict(kernel_size=ks, stride=stride, padding=ks//2, bias=bias)
    return nn.Conv2d(ni, nf, **opts)
def conv_layer(conv_1st, ni, nf, ks=3, stride=1, zero_bn=False, bias=False):
    "Conv->act->BN if `conv_1st` else act->BN->conv; the BN weight starts at 0 when `zero_bn`, else 1."
    # BN sits after the conv (`nf` channels) or before it (`ni` channels).
    n_bn = nf if conv_1st else ni
    bn = nn.BatchNorm2d(n_bn)
    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
    act = act_fn()
    cn = conv(ni, nf, ks, stride=stride, bias=bias)
    mods = [cn, act, bn] if conv_1st else [act, bn, cn]
    return nn.Sequential(*mods)
def conv_act(*args, **kwargs):
    "`conv_layer` in conv-first order (conv -> activation -> batchnorm)."
    return conv_layer(True, *args, **kwargs)
def act_conv(*args, **kwargs):
    "`conv_layer` in conv-last order (activation -> batchnorm -> conv)."
    return conv_layer(False, *args, **kwargs)
class BasicBlock(nn.Module):
    "Residual block of two 3x3 `act_conv` layers (the second with a zero-initialised BN weight) plus a shortcut."
    # Subclasses `nn.Module` directly: this module never imports a bare `Module`, and the
    # constructor already calls `super().__init__()` explicitly.
    expansion = 1

    def __init__(self, ni, nf, stride=1, downsample=None):
        "`ni` -> `nf` channels; `stride` applies to the first conv; `downsample`, if given, maps the shortcut."
        super().__init__()
        self.conv1 = act_conv(ni, nf, stride=stride)
        self.conv2 = act_conv(nf, nf, zero_bn=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        "Return `conv2(conv1(x))` plus `x`, or plus `downsample(x)` when a downsample module is set."
        identity = x if self.downsample is None else self.downsample(x)
        x = self.conv1(x)
        x = self.conv2(x)
        # In-place add onto the conv output, as in the original implementation.
        x += identity
        return x
class Bottleneck(nn.Module):
    "Residual block of 1x1, 3x3 and 1x1 `act_conv` layers, ending in `nf*expansion` channels, plus a shortcut."
    # Subclasses `nn.Module` directly: this module never imports a bare `Module`, and the
    # constructor already calls `super().__init__()` explicitly.
    expansion = 4

    def __init__(self, ni, nf, stride=1, downsample=None):
        "`ni` -> `nf*expansion` channels; `stride` applies to the 3x3 conv; `downsample`, if given, maps the shortcut."
        super().__init__()
        self.conv1 = act_conv(ni, nf, 1)
        self.conv2 = act_conv(nf, nf, stride=stride)
        self.conv3 = act_conv(nf, nf*self.expansion, 1)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        "Return the three-conv path plus `x`, or plus `downsample(x)` when a downsample module is set."
        identity = x if self.downsample is None else self.downsample(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # In-place add onto the conv output, as in the original implementation.
        x += identity
        return x
class PResNet(nn.Module):
    "ResNet variant built from `conv_act` stem layers and `block`s that use activation->BN->conv ordering."
    # Subclasses `nn.Module` directly: this module never imports a bare `Module`, and the
    # constructor already calls `super().__init__()` explicitly.
    def __init__(self, block, layers, num_classes=1000):
        "`block` is the residual block class, `layers` gives the block count of each of the four stages."
        super().__init__()
        # Running channel count fed to the next stage; updated by `_make_layer`.
        self.ni = 64
        self.conv1 = conv_act(3, 16, stride=2)
        self.conv2 = conv_act(16, 32)
        self.conv3 = conv_act(32, 64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        ni = 512*block.expansion
        # Final activation + BN are folded into the pooling head.
        self.avgpool = nn.Sequential(
            act_fn(), nn.BatchNorm2d(ni), nn.AdaptiveAvgPool2d(1))
        self.fc = nn.Linear(ni, num_classes)
        init_cnn(self)

    def _make_layer(self, block, nf, blocks, stride=1):
        "Stack `blocks` instances of `block`; only the first gets `stride` and (if needed) a downsample shortcut."
        downsample = None
        if stride != 1 or self.ni != nf*block.expansion:
            # Shortcut: optional act/BN/avg-pool (stride 2 only) followed by a conv to the new width.
            layers = [act_fn(), nn.BatchNorm2d(self.ni),
                      nn.AvgPool2d(kernel_size=2)] if stride==2 else []
            layers.append(conv(self.ni, nf*block.expansion))
            downsample = nn.Sequential(*layers)
        layers = [block(self.ni, nf, stride, downsample)]
        self.ni = nf*block.expansion
        for i in range(1, blocks): layers.append(block(self.ni, nf))
        return nn.Sequential(*layers)

    def forward(self, x):
        "Stem -> max-pool -> four stages -> pooling head -> flatten -> linear classifier."
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Weight locations by architecture name; the values are handed to `torch.load` below.
model_urls = {'presnet34': 'presnet34', 'presnet50': 'presnet50'}
def presnet(block, n_layers, name, pre=False, **kwargs):
    "Build a `PResNet` from `block` and `n_layers`; when `pre`, load the state dict found at `model_urls[name]`."
    model = PResNet(block, n_layers, **kwargs)
    # Alternative kept from the original: model.load_state_dict(model_zoo.load_url(model_urls[name]))
    if pre:
        state = torch.load(model_urls[name])
        model.load_state_dict(state)
    return model
def presnet18(pretrained=False, **kwargs):
    "`PResNet` of `BasicBlock`s with [2, 2, 2, 2] blocks per stage."
    depths = [2, 2, 2, 2]
    return presnet(BasicBlock, depths, 'presnet18', pre=pretrained, **kwargs)
def presnet34(pretrained=False, **kwargs):
    "`PResNet` of `BasicBlock`s with [3, 4, 6, 3] blocks per stage."
    depths = [3, 4, 6, 3]
    return presnet(BasicBlock, depths, 'presnet34', pre=pretrained, **kwargs)
def presnet50(pretrained=False, **kwargs):
    "`PResNet` of `Bottleneck` blocks with [3, 4, 6, 3] blocks per stage."
    depths = [3, 4, 6, 3]
    return presnet(Bottleneck, depths, 'presnet50', pre=pretrained, **kwargs)
def presnet101(pretrained=False, **kwargs):
    "`PResNet` of `Bottleneck` blocks with [3, 4, 23, 3] blocks per stage."
    depths = [3, 4, 23, 3]
    return presnet(Bottleneck, depths, 'presnet101', pre=pretrained, **kwargs)
def presnet152(pretrained=False, **kwargs):
    "`PResNet` of `Bottleneck` blocks with [3, 8, 36, 3] blocks per stage."
    depths = [3, 8, 36, 3]
    return presnet(Bottleneck, depths, 'presnet152', pre=pretrained, **kwargs)
================================================
FILE: fastai/vision/models/unet.py
================================================
from ...torch_core import *
from ...layers import *
from ...callbacks.hooks import *
__all__ = ['DynamicUnet', 'UnetBlock']
def _get_sfs_idxs(sizes:Sizes) -> List[int]:
    "Get the indexes `i` where the last dimension of `sizes[i]` differs from that of `sizes[i+1]`."
    last_dims = np.array([sz[-1] for sz in sizes])
    changes = np.where(last_dims[:-1] != last_dims[1:])[0]
    idxs = list(changes)
    # NOTE: when the first two sizes differ, 0 is prepended in addition to the 0 already found above.
    if last_dims[0] != last_dims[1]: idxs = [0] + idxs
    return idxs
class UnetBlock(Module):
    "A quasi-UNet block: upsample with `PixelShuffle_ICNR`, concatenate the hooked activation, then two convs."
    def __init__(self, up_in_c:int, x_in_c:int, hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
                 self_attention:bool=False, **kwargs):
        "`up_in_c` channels come from below, `x_in_c` from the activation stored by `hook`."
        half_up = up_in_c//2
        self.hook = hook
        self.shuf = PixelShuffle_ICNR(up_in_c, half_up, blur=blur, leaky=leaky, **kwargs)
        self.bn = batchnorm_2d(x_in_c)
        # Channel count after concatenating the upsampled input with the hooked activation.
        ni = half_up + x_in_c
        nf = ni if final_div else ni//2
        self.conv1 = conv_layer(ni, nf, leaky=leaky, **kwargs)
        self.conv2 = conv_layer(nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
        self.relu = relu(leaky=leaky)

    def forward(self, up_in:Tensor) -> Tensor:
        "Upsample `up_in`, resize it to the stored activation if their spatial sizes differ, merge and convolve."
        skip = self.hook.stored
        up_out = self.shuf(up_in)
        target_hw = skip.shape[-2:]
        if target_hw != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, skip.shape[-2:], mode='nearest')
        merged = torch.cat([up_out, self.bn(skip)], dim=1)
        merged = self.relu(merged)
        return self.conv2(self.conv1(merged))
class DynamicUnet(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder:nn.Module, n_classes:int, img_size:Tuple[int,int]=(256,256), blur:bool=False, blur_final:bool=True, self_attention:bool=False,
                 y_range:Optional[Tuple[float,float]]=None,
                 last_cross:bool=True, bottle:bool=False, **kwargs):
        # The decoder is sized dynamically: a dummy input is pushed through the encoder and through
        # every decoder block as it is created, and `x` tracks the resulting activation.
        imsize = img_size
        sfs_szs = model_sizes(encoder, size=imsize)
        # Indexes of the encoder layers to hook, reversed so the decoder visits the deepest first.
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        x = dummy_eval(encoder, imsize).detach()
        # Channel count of the last encoder activation.
        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs),
                                    conv_layer(ni*2, ni, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
        for i,idx in enumerate(sfs_idxs):
            not_final = i!=len(sfs_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # Blur is skipped on the final block unless `blur_final` is set.
            do_blur = blur and (not_final or blur_final)
            # Self-attention is only requested for the third block from the end.
            sa = self_attention and (i==len(sfs_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   **kwargs).eval()
            layers.append(unet_block)
            # Run the new block on the dummy activation so the next block sees the right channel count.
            x = unet_block(x)
        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
        # A separate, throwaway PixelShuffle_ICNR is applied to the dummy activation regardless of the
        # branch above; only the resulting shape is used below.
        x = PixelShuffle_ICNR(ni)(x)
        if imsize != x.shape[-2:]: layers.append(Lambda(lambda x: F.interpolate(x, imsize, mode='nearest')))
        if last_cross:
            # NOTE(review): `MergeLayer(dense=True)` presumably concatenates the original input, which
            # would explain adding the encoder's input channels to `ni` — confirm in `...layers`.
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, **kwargs))
        # Final 1x1 conv to `n_classes` channels, without activation.
        layers += [conv_layer(ni, n_classes, ks=1, use_activ=False, **kwargs)]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # Remove the hooks; guarded because `__init__` may have failed before `self.sfs` was set.
        if hasattr(self, "sfs"): self.sfs.remove()
================================================
FILE: fastai/vision/models/wrn.py
================================================
from ...layers import *
from ...torch_core import *
__all__ = ['BasicBlock', 'WideResNet', 'wrn_22']
def _bn(ni, init_zero=False):
"Batchnorm layer with 0 initialization"
m = nn.BatchNorm2d(ni)
m.weight.data.fill_(0 if init_zero else 1)
m.bias.data.zero_()
return m
def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    "Sequence BatchNorm (`_bn`, optionally zero-initialised) -> in-place ReLU -> `conv2d(ni, nf, ks, stride)`."
    norm = _bn(ni, init_zero=init_zero)
    act = nn.ReLU(inplace=True)
    conv = conv2d(ni, nf, ks, stride)
    return nn.Sequential(norm, act, conv)
class BasicBlock(Module):
    "Block to form a wide ResNet."
    def __init__(self, ni, nf, stride, drop_p=0.0):
        "`ni` -> `nf` channels with `stride` on the first conv; dropout is only created when `drop_p` is non-zero."
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        if drop_p: self.drop = nn.Dropout(drop_p, inplace=True)
        else: self.drop = None
        # A 1x1 conv shortcut is only needed when the channel count changes.
        if ni != nf: self.shortcut = conv2d(ni, nf, 1, stride)
        else: self.shortcut = noop

    def forward(self, x):
        "BN+ReLU the input, then add the shortcut to the conv path scaled by 0.2."
        pre = F.relu(self.bn(x), inplace=True)
        res = self.shortcut(pre)
        out = self.conv1(pre)
        if self.drop: out = self.drop(out)
        out = self.conv2(out) * 0.2
        return out.add_(res)
def _make_group(N, ni, nf, block, stride, drop_p):
return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]
class WideResNet(Module):
    "Wide ResNet with `num_groups` and a width of `k`."
    def __init__(self, num_groups:int, N:int, num_classes:int, k:int=1, drop_p:float=0.0, start_nf:int=16, n_in_channels:int=3):
        "`N` blocks per group; group `i` has `start_nf * 2**i * k` channels."
        n_channels = [start_nf] + [start_nf*(2**i)*k for i in range(num_groups)]
        layers = [conv2d(n_in_channels, n_channels[0], 3, 1)]  # conv1
        for i in range(num_groups):
            # The first group keeps the resolution, later groups use stride 2.
            stride = 1 if i == 0 else 2
            layers.extend(_make_group(N, n_channels[i], n_channels[i+1], BasicBlock, stride, drop_p))
        n_last = n_channels[num_groups]
        layers.extend([nn.BatchNorm2d(n_last), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                       Flatten(), nn.Linear(n_last, num_classes)])
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)
def wrn_22():
    "Wide ResNet with 22 layers."
    config = dict(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)
    return WideResNet(**config)
================================================
FILE: fastai/vision/models/xception.py
================================================
from ...vision import *
__all__ = ['xception']
def sep_conv(ni,nf,pad=None,pool=False,act=True):
    "Optional ReLU, then a per-channel 3x3 conv, a 1x1 conv to `nf` and BatchNorm; optional 2x2 max-pool. `pad` is unused."
    layers = []
    if act: layers.append(nn.ReLU())
    # `groups=ni` makes the 3x3 conv act on each channel separately.
    layers.append(nn.Conv2d(ni, ni, kernel_size=3, stride=1, padding=1, groups=ni, bias=False))
    layers.append(nn.Conv2d(ni, nf, kernel_size=1, bias=False))
    layers.append(nn.BatchNorm2d(nf))
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)
def conv(ni,nf,ks=1,stride=1, pad=None, act=True):
    "Bias-free Conv2d -> BatchNorm2d, followed by ReLU when `act`; `pad` defaults to `ks//2`."
    padding = ks//2 if pad is None else pad
    body = [nn.Conv2d(ni, nf, ks, stride, padding, bias=False), nn.BatchNorm2d(nf)]
    if act: body.append(nn.ReLU())
    return nn.Sequential(*body)
class ConvSkip(Module):
    "Two separable convs with a 2x2 max-pool, added to a stride-2 1x1 conv shortcut of the input."
    def __init__(self,ni,nf=None,act=True):
        "`ni` -> `nf` channels (`nf` defaults to `ni`); `act` toggles the ReLU in front of the first separable conv."
        # Resolve the default before any layer is built: the original stored the fallback in
        # `self.nf` but still passed the raw `nf` (None) to the layer constructors.
        if nf is None: nf = ni
        self.nf,self.ni = nf,ni
        self.conv = conv(ni,nf,stride=2, act=False)
        self.m = nn.Sequential(
            sep_conv(ni,ni,act=act),
            sep_conv(ni,nf,pool=True)
        )

    def forward(self,x):
        "Sum of the strided shortcut and the pooled separable-conv path."
        return self.conv(x) + self.m(x)
def middle_flow(nf):
    "Three `sep_conv(nf, nf)` layers followed by a `MergeLayer`, wrapped in a `SequentialEx`."
    convs = []
    for _ in range(3): convs.append(sep_conv(nf, nf))
    return SequentialEx(*convs, MergeLayer())
def xception(c, k=8, n_middle=8):
    "Preview version of Xception network. Not tested yet - use at own risk. No pretrained model yet."
    # All widths are multiples of `k`; `c` is the number of outputs of the final linear layer.
    wide = k*91
    entry = [
        conv(3, k*4, 3, 2),
        conv(k*4, k*8, 3),
        ConvSkip(k*8, k*16, act=False),
        ConvSkip(k*16, k*32),
        ConvSkip(k*32, wide),
    ]
    middle = [middle_flow(wide) for _ in range(n_middle)]
    tail = [
        ConvSkip(wide, k*128),
        sep_conv(k*128, k*192, act=False),
        sep_conv(k*192, k*256),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        Flatten(),
        nn.Linear(k*256, c),
    ]
    return nn.Sequential(*entry, *middle, *tail)
================================================
FILE: fastai/vision/models/xresnet.py
================================================
import torch.nn as nn
import torch,math,sys
import torch.utils.model_zoo as model_zoo
from functools import partial
from ...torch_core import Module
__all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152']
# or: ELU+init (a=0.54; gain=1.55)
act_fn = nn.ReLU(inplace=True)
class Flatten(Module):
    "Reshape the input to `(x.size(0), -1)`."
    def forward(self, x):
        n = x.size(0)
        return x.view(n, -1)
def init_cnn(m):
    "Recursively initialise `m`: zero every bias and Kaiming-normal the weights of conv and linear layers."
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    for child in m.children():
        init_cnn(child)
def conv(ni, nf, ks=3, stride=1, bias=False):
    "`nn.Conv2d` from `ni` to `nf` channels with kernel `ks`, `stride` and `ks//2` padding."
    opts = dict(kernel_size=ks, stride=stride, padding=ks//2, bias=bias)
    return nn.Conv2d(ni, nf, **opts)
def noop(x):
    "Identity: return `x` unchanged."
    return x
def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    "Conv -> BatchNorm (weight 0 if `zero_bn` else 1), followed by the module-level `act_fn` when `act`."
    bn = nn.BatchNorm2d(nf)
    bn_weight = 0. if zero_bn else 1.
    nn.init.constant_(bn.weight, bn_weight)
    cn = conv(ni, nf, ks, stride=stride)
    # The same `act_fn` object is appended to every layer; it is not copied.
    mods = [cn, bn, act_fn] if act else [cn, bn]
    return nn.Sequential(*mods)
class ResBlock(Module):
    "Residual block: two 3x3 convs when `expansion == 1`, otherwise 1x1 -> 3x3 -> 1x1; the last BN weight starts at 0."
    def __init__(self, expansion, ni, nh, stride=1):
        "`ni` and `nh` are given before expansion; the block maps `ni*expansion` to `nh*expansion` channels."
        nf, ni = nh*expansion, ni*expansion
        if expansion == 1:
            layers = [conv_layer(ni, nh, 3, stride=stride),
                      conv_layer(nh, nf, 3, zero_bn=True, act=False)]
        else:
            layers = [conv_layer(ni, nh, 1),
                      conv_layer(nh, nh, 3, stride=stride),
                      conv_layer(nh, nf, 1, zero_bn=True, act=False)]
        self.convs = nn.Sequential(*layers)
        # TODO: check whether act=True works better
        if ni == nf: self.idconv = noop
        else: self.idconv = conv_layer(ni, nf, 1, act=False)
        # The shortcut is average-pooled instead of strided when the block downsamples.
        if stride == 1: self.pool = noop
        else: self.pool = nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        main = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return act_fn(main + shortcut)
def filt_sz(recep):
    "Largest power of two not exceeding `0.75*recep`, capped at 64."
    exponent = math.floor(math.log2(recep*0.75))
    return min(64, 2**exponent)
class XResNet(nn.Sequential):
    "Sequential ResNet variant: three-conv stem, max-pool, one stage per entry of `layers`, pooled linear head."
    def __init__(self, expansion, layers, c_in=3, c_out=1000):
        "`expansion` is forwarded to every `ResBlock`; `layers` gives the number of blocks in each stage."
        sizes = [c_in, 32, 32, 64]
        stem = []
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            # Only the first stem conv downsamples.
            stem.append(conv_layer(n_in, n_out, stride=2 if i==0 else 1))
            # Disabled alternative kept from the original, sizing the stem with `filt_sz`:
            #nf = filt_sz(c_in*9)
            #stem.append(conv_layer(c_in, nf, stride=2 if i==1 else 1))
            #c_in = nf
        block_szs = [64//expansion, 64, 128, 256, 512]
        blocks = []
        for i, n_blocks in enumerate(layers):
            stride = 1 if i==0 else 2
            blocks.append(self._make_layer(expansion, block_szs[i], block_szs[i+1], n_blocks, stride))
        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        head = [nn.AdaptiveAvgPool2d(1), Flatten(), nn.Linear(block_szs[-1]*expansion, c_out)]
        super().__init__(*stem, pool, *blocks, *head)
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride):
        "Stack `blocks` `ResBlock`s; only the first takes `ni` inputs and `stride`."
        stage = []
        for i in range(blocks):
            first = i == 0
            stage.append(ResBlock(expansion, ni if first else nf, nf, stride if first else 1))
        return nn.Sequential(*stage)
def xresnet(expansion, n_layers, name, pretrained=False, **kwargs):
    "Build an `XResNet`; when `pretrained`, load the weights registered for `name` in a module-level `model_urls`."
    url = None
    if pretrained:
        # This module defines no `model_urls`, so the original built the whole network and then
        # failed with a NameError. Look the mapping up defensively and fail fast with a clear
        # KeyError before constructing the model.
        url = (globals().get('model_urls') or {}).get(name)
        if url is None:
            raise KeyError(f'no pretrained weights registered for {name}')
    model = XResNet(expansion, n_layers, **kwargs)
    if url is not None: model.load_state_dict(model_zoo.load_url(url))
    return model
# Publish `xresnet18` ... `xresnet152` as attributes of this module. Each one is `xresnet` with its
# expansion, per-stage block counts and name pre-bound through `partial`.
me = sys.modules[__name__]
for n,e,l in [
    # depth, expansion, blocks per stage
    [ 18 , 1, [2,2,2 ,2] ],
    [ 34 , 1, [3,4,6 ,3] ],
    [ 50 , 4, [3,4,6 ,3] ],
    [ 101, 4, [3,4,23,3] ],
    [ 152, 4, [3,8,36,3] ],
]:
    name = f'xresnet{n}'
    setattr(me, name, partial(xresnet, expansion=e, n_layers=l, name=name))
================================================
FILE: fastai/vision/models/xresnet2.py
================================================
import torch.nn as nn
import torch
import math
import torch.utils.model_zoo as model_zoo
from ...torch_core import Module
__all__ = ['XResNet', 'xresnet18', 'xresnet34_2', 'xresnet50_2', 'xresnet101', 'xresnet152']
def conv3x3(in_planes, out_planes, stride=1):
    "Bias-free 3x3 `nn.Conv2d` with padding 1."
    opts = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **opts)
class BasicBlock(Module):
    "Residual block of two 3x3 conv+BN layers, with a ReLU after the first and after the residual sum."
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        "`stride` applies to the first conv; `downsample`, if given, transforms the shortcut."
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The shortcut is the raw input unless a downsample module was supplied.
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
class Bottleneck(Module):
    "Residual block of 1x1, 3x3 and 1x1 conv+BN layers ending in `planes*expansion` channels."
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        "`stride` applies to the 3x3 conv; `downsample`, if given, transforms the shortcut."
        super().__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # The shortcut is the raw input unless a downsample module was supplied.
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
def conv2d(ni, nf, stride):
    "Bias-free 3x3 conv (padding 1) -> BatchNorm2d -> in-place ReLU."
    cn = nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False)
    bn = nn.BatchNorm2d(nf)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(cn, bn, act)
class XResNet(Module):
    "ResNet variant with a three-conv stem and average-pooled downsample shortcuts."
    def __init__(self, block, layers, c_out=1000):
        "`block` is the residual block class, `layers` the block count of each of the four stages."
        # Running channel count fed to the next stage; updated by `_make_layer`.
        self.inplanes = 64
        super().__init__()
        self.conv1 = conv2d(3, 32, 2)
        self.conv2 = conv2d(32, 32, 1)
        self.conv3 = conv2d(32, 64, 1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, c_out)
        # First pass: generic conv / batchnorm initialisation.
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)
        # Second pass: zero the last BN weight of every block and re-draw the linear weights.
        for mod in self.modules():
            if isinstance(mod, BasicBlock):
                mod.bn2.weight = nn.Parameter(torch.zeros_like(mod.bn2.weight))
            if isinstance(mod, Bottleneck):
                mod.bn3.weight = nn.Parameter(torch.zeros_like(mod.bn3.weight))
            if isinstance(mod, nn.Linear):
                mod.weight.data.normal_(0, 0.01)

    def _make_layer(self, block, planes, blocks, stride=1):
        "Stack `blocks` instances of `block`; only the first gets `stride` and (if needed) a downsample shortcut."
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            shortcut = []
            if stride==2: shortcut.append(nn.AvgPool2d(kernel_size=2, stride=2))
            shortcut.append(nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=1, bias=False))
            shortcut.append(nn.BatchNorm2d(out_planes))
            downsample = nn.Sequential(*shortcut)
        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))
        return nn.Sequential(*stage)

    def forward(self, x):
        "Stem -> max-pool -> four stages -> global average pool -> flatten -> linear classifier."
        for stem in (self.conv1, self.conv2, self.conv3):
            x = stem(x)
        x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        flat = x.view(x.size(0), -1)
        return self.fc(flat)
def xresnet18(pretrained=False, **kwargs):
    """Constructs a XResNet-18 model.
    Args:
        pretrained (bool): If True, load the weights registered under 'xresnet18' in a
            module-level `model_urls` mapping.
    Raises:
        KeyError: `pretrained` is True but no such weights are registered.
    """
    url = None
    if pretrained:
        # This module defines no `model_urls`, so the original failed with a NameError after
        # building the network. Look it up defensively and fail fast with a clear error.
        url = (globals().get('model_urls') or {}).get('xresnet18')
        if url is None: raise KeyError('no pretrained weights registered for xresnet18')
    model = XResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if url is not None: model.load_state_dict(model_zoo.load_url(url))
    return model
def xresnet34_2(pretrained=False, **kwargs):
    """Constructs a XResNet-34 model.
    Args:
        pretrained (bool): If True, load the weights registered under 'xresnet34' in a
            module-level `model_urls` mapping.
    Raises:
        KeyError: `pretrained` is True but no such weights are registered.
    """
    url = None
    if pretrained:
        # This module defines no `model_urls`, so the original failed with a NameError after
        # building the network. Look it up defensively and fail fast with a clear error.
        url = (globals().get('model_urls') or {}).get('xresnet34')
        if url is None: raise KeyError('no pretrained weights registered for xresnet34')
    model = XResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if url is not None: model.load_state_dict(model_zoo.load_url(url))
    return model
def xresnet50_2(pretrained=False, **kwargs):
    """Constructs a XResNet-50 model.
    Args:
        pretrained (bool): If True, load the weights registered under 'xresnet50' in a
            module-level `model_urls` mapping.
    Raises:
        KeyError: `pretrained` is True but no such weights are registered.
    """
    url = None
    if pretrained:
        # This module defines no `model_urls`, so the original failed with a NameError after
        # building the network. Look it up defensively and fail fast with a clear error.
        url = (globals().get('model_urls') or {}).get('xresnet50')
        if url is None: raise KeyError('no pretrained weights registered for xresnet50')
    model = XResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if url is not None: model.load_state_dict(model_zoo.load_url(url))
    return model
def xresnet101(pretrained=False, **kwargs):
    """Constructs a XResNet-101 model.
    Args:
        pretrained (bool): If True, load the weights registered under 'xresnet101' in a
            module-level `model_urls` mapping.
    Raises:
        KeyError: `pretrained` is True but no such weights are registered.
    """
    url = None
    if pretrained:
        # This module defines no `model_urls`, so the original failed with a NameError after
        # building the network. Look it up defensively and fail fast with a clear error.
        url = (globals().get('model_urls') or {}).get('xresnet101')
        if url is None: raise KeyError('no pretrained weights registered for xresnet101')
    model = XResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if url is not None: model.load_state_dict(model_zoo.load_url(url))
    return model
def xresnet152(pretrained=False, **kwargs):
    """Constructs a XResNet-152 model.
    Args:
        pretrained (bool): If True, load the weights registered under 'xresnet152' in a
            module-level `model_urls` mapping.
    Raises:
        KeyError: `pretrained` is True but no such weights are registered.
    """
    url = None
    if pretrained:
        # This module defines no `model_urls`, so the original failed with a NameError after
        # building the network. Look it up defensively and fail fast with a clear error.
        url = (globals().get('model_urls') or {}).get('xresnet152')
        if url is None: raise KeyError('no pretrained weights registered for xresnet152')
    model = XResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if url is not None: model.load_state_dict(model_zoo.load_url(url))
    return model
================================================
FILE: fastai/vision/transform.py
================================================
"Image transformations for data augmentation. All transforms are done on the tensor level"
from ..torch_core import *
from .image import *
from .image import _affine_mult
__all__ = ['brightness', 'contrast', 'crop', 'crop_pad', 'cutout', 'dihedral', 'dihedral_affine', 'flip_affine', 'flip_lr',
'get_transforms', 'jitter', 'pad', 'perspective_warp', 'rand_pad', 'rand_crop', 'rand_zoom', 'rgb_randomize', 'rotate', 'skew', 'squish',
'rand_resize_crop', 'symmetric_warp', 'tilt', 'zoom', 'zoom_crop']
_pad_mode_convert = {'reflection':'reflect', 'zeros':'constant', 'border':'replicate'}
#NB: Although TfmLighting etc can be used as decorators, that doesn't work in Windows,
# so we do it manually for now.
def _brightness(x, change:uniform):
    "Shift `x` in place by `logit(change)` to alter its brightness."
    shift = scipy.special.logit(change)
    return x.add_(shift)
# Wrapped by hand rather than with a decorator (see the note above about Windows).
brightness = TfmLighting(_brightness)
def _contrast(x, scale:log_uniform):
    "Multiply `x` in place by `scale` to alter its contrast."
    x.mul_(scale)
    return x
contrast = TfmLighting(_contrast)
def _rotate(degrees:uniform):
    "Return the 3x3 affine matrix rotating by `degrees`."
    angle = degrees * math.pi / 180
    c, s = float(cos(angle)), float(sin(angle))
    neg_s = float(-sin(angle))
    return [[c, neg_s, 0.],
            [s, c,     0.],
            [0., 0.,   1.]]
rotate = TfmAffine(_rotate)
def _get_zoom_mat(sw:float, sh:float, c:float, r:float)->AffineMatrix:
    "`sw`,`sh` scale width,height - `c`,`r` focus col,row."
    top = [sw, 0, c]
    mid = [0, sh, r]
    bottom = [0, 0, 1.]
    return [top, mid, bottom]
def _zoom(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5):
    "Zoom image by `scale`. `row_pct`,`col_pct` select focal point of zoom."
    # `s` is the room left for moving the focal point once the view is shrunk by 1/scale.
    s = 1-1/scale
    col_c, row_c = s * (2*col_pct - 1), s * (2*row_pct - 1)
    return _get_zoom_mat(1/scale, 1/scale, col_c, row_c)
zoom = TfmAffine(_zoom)
def _squish(scale:uniform=1.0, row_pct:uniform=0.5, col_pct:uniform=0.5):
    "Squish image by `scale`. `row_pct`,`col_pct` select focal point of zoom."
    if scale > 1:
        # Scale above 1: only the height is rescaled, by 1/scale.
        row_c = (1-1/scale) * (2*row_pct - 1)
        return _get_zoom_mat(1, 1/scale, 0., row_c)
    # Scale up to 1: only the width is rescaled, by `scale`.
    col_c = (1-scale) * (2*col_pct - 1)
    return _get_zoom_mat(scale, 1, col_c, 0.)
squish = TfmAffine(_squish)
def _jitter(c, magnitude:uniform):
    "Add uniform noise in [-`magnitude`, `magnitude`) to the flow of `c`, in place."
    centred = torch.rand_like(c.flow)-0.5
    c.flow.add_(centred*magnitude*2)
    return c
jitter = TfmCoord(_jitter)
def _flip_lr(x):
    "Flip `x` horizontally."
    #return x.flip(2)
    if not isinstance(x, ImagePoints):
        # Reverse the last axis through numpy and copy into a contiguous tensor.
        flipped = np.array(x)[...,::-1]
        return tensor(np.ascontiguousarray(flipped))
    # For points, negate the first coordinate of the flow in place.
    x.flow.flow[...,0] *= -1
    return x
flip_lr = TfmPixel(_flip_lr)
def _flip_affine() -> AffineMatrix:
    "Return the 3x3 affine matrix of a horizontal flip."
    # The return annotation used to be `TfmAffine`, but the function returns the raw matrix;
    # the `TfmAffine` wrapper is applied on the line below.
    return [[-1, 0, 0.],
            [0, 1, 0],
            [0, 0, 1.]]
flip_affine = TfmAffine(_flip_affine)
def _dihedral(x, k:partial(uniform_int,0,7)):
    "Flip and/or transpose image `x` according to the three low bits of `k`."
    # Bit 1 flips dim 1, bit 2 flips dim 2, bit 4 swaps dims 1 and 2.
    flips = [dim for dim in (1, 2) if k & dim]
    if flips: x = torch.flip(x, flips)
    if k&4: x = x.transpose(1,2)
    return x.contiguous()
dihedral = TfmPixel(_dihedral)
def _dihedral_affine(k:partial(uniform_int,0,7)):
    "Return the 3x3 affine matrix of the flip/transpose selected by the three low bits of `k`."
    x = -1 if k&1 else 1
    y = -1 if k&2 else 1
    if k&4:
        # Bit 4 swaps the axes, so the signs move to the off-diagonal.
        rows = [[0, x, 0.], [y, 0, 0]]
    else:
        rows = [[x, 0, 0.], [0, y, 0]]
    rows.append([0, 0, 1.])
    return rows
dihedral_affine = TfmAffine(_dihedral_affine)
def _pad_coord(x, row_pad:int, col_pad:int, mode='zeros'):
    "Grow the size of `x` by `row_pad`/`col_pad` on each side and rescale its flow to match. `mode` is unused."
    #TODO: implement other padding modes than zeros?
    h,w = x.size
    new_h, new_w = h+2*row_pad, w+2*col_pad
    # Coordinates shrink by the ratio of the old size to the padded size.
    ratio = torch.Tensor([w/(w + 2*col_pad), h/(h + 2*row_pad)])
    x.flow = FlowField((new_h, new_w), x.flow.flow * ratio[None])
    return x
def _pad_default(x, padding:int, mode='reflection'):
    "Pad `x` with `padding` pixels. `mode` fills in space ('zeros','reflection','border')."
    torch_mode = _pad_mode_convert[mode]
    # `F.pad` is applied to a batch of one, which is unwrapped again afterwards.
    batched = F.pad(x[None], (padding,)*4, mode=torch_mode)
    return batched[0]
def _pad_image_points(x, padding:int, mode='reflection'):
    "Pad the points `x` by `padding` on both rows and columns via `_pad_coord`."
    return _pad_coord(x, row_pad=padding, col_pad=padding, mode=mode)
def _pad(x, padding:int, mode='reflection'):
    "Pad `x` by `padding`, dispatching on whether `x` is an `ImagePoints`."
    if isinstance(x, ImagePoints): return _pad_image_points(x, padding, mode)
    return _pad_default(x, padding, mode)
pad = TfmPixel(_pad, order=-10)
def _cutout(x, n_holes:uniform_int=1, length:uniform_int=40):
    "Cut out `n_holes` number of square holes of size `length` in image at random locations."
    h,w = x.shape[1:]
    half = length / 2
    for _ in range(n_holes):
        # Pick the hole centre (row first, then column), then clip the square to the image.
        h_y = np.random.randint(0, h)
        h_x = np.random.randint(0, w)
        y1, y2 = int(np.clip(h_y - half, 0, h)), int(np.clip(h_y + half, 0, h))
        x1, x2 = int(np.clip(h_x - half, 0, w)), int(np.clip(h_x + half, 0, w))
        x[:, y1:y2, x1:x2] = 0
    return x
cutout = TfmPixel(_cutout, order=20)
def _rgb_randomize(x, channel:int=None, thresh:float=0.3):
"Randomize one of the channels of the input image"
if channel is None: channel = np.random.randint(0, x.shape[0] - 1)
x[channel] = torch.rand(x.shape[1:]) * np.random.uniform(0, thresh)
return x
# Public transform: `_rgb_randomize` wrapped in `TfmPixel`.
rgb_randomize = TfmPixel(_rgb_randomize)
def _minus_epsilon(row_pct:float, col_pct:float, eps:float=1e-7):
if row_pct==1.: row_pct -= 1e-7
if col_pct==1.: col_pct -= 1e-7
return row_pct,col_pct
def _crop_default(x, size, row_pct:uniform=0.5, col_pct:uniform=0.5):
    "Crop `x` to `size` pixels. `row_pct`,`col_pct` select focal point of crop."
    rows,cols = tis2hw(size)
    row_pct,col_pct = _minus_epsilon(row_pct,col_pct)
    # Top-left corner of the window, as a fraction of the available slack.
    top = int((x.size(1)-rows+1) * row_pct)
    left = int((x.size(2)-cols+1) * col_pct)
    window = x[:, top:top+rows, left:left+cols]
    return window.contiguous()
def _crop_image_points(x, size, row_pct=0.5, col_pct=0.5):
    "Crop the points `x` to `size`: rescale and shift their flow in place, then set `x.size`."
    h,w = x.size
    rows,cols = tis2hw(size)
    row_pct,col_pct = _minus_epsilon(row_pct,col_pct)
    flow = x.flow.flow
    # Rescale the coordinates from the full image to the crop window.
    flow.mul_(torch.Tensor([w/cols, h/rows])[None])
    row = int((h-rows+1) * row_pct)
    col = int((w-cols+1) * col_pct)
    # Shift by the position of the window.
    offset = torch.Tensor([w/cols-2*col/cols, h/rows-2*row/rows])
    flow.add_(-1 + offset[None])
    x.size = (rows, cols)
    return x
def _crop(x, size, row_pct:uniform=0.5, col_pct:uniform=0.5):
    "Crop `x` to `size`, dispatching on whether `x` is an `ImagePoints`."
    if isinstance(x, ImagePoints): return _crop_image_points(x, size, row_pct, col_pct)
    return _crop_default(x, size, row_pct, col_pct)
crop = TfmPixel(_crop)
def _crop_pad_default(x, size, padding_mode='reflection', row_pct:uniform = 0.5, col_pct:uniform = 0.5):
"Crop and pad tfm - `row_pct`,`col_pct` sets focal point."
padding_mode = _pad_mode_convert[padding_mode]
size = tis2hw(size)
if x.shape[1:] == torch.Size(size): return x
rows,cols = size
row_pct,col_pct = _minus_epsilon(row_pct,col_pct)
if x.size(1)Tensor:
"Find 8 coeff mentioned [here](https://web.archive.org/web/20150222120106/xenia.media.mit.edu/~cwren/interpolator/)."
matrix = []
#The equations we'll need to solve.
for p1, p2 in zip(targ_pts, orig_pts):
matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]])
matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]])
A = FloatTensor(matrix)
B = FloatTensor(orig_pts).view(8, 1)
#The 8 scalars we seek are solution of AX = B
return torch.linalg.solve(A,B)[:,0]
def _apply_perspective(coords:FlowField, coeffs:Points)->FlowField:
    "Transform `coords` with `coeffs`."
    orig_shape = coords.flow.size()
    # Collapse all the dims except the last one, so the points become N * 2.
    flat = coords.flow.view(-1,2)
    # Turn the 8 coeffs into a 3*3 matrix by appending a 1 as its last entry.
    mat = torch.cat([coeffs, FloatTensor([1])]).view(3,3)
    # N * 3 result: the last column of `mat` is added to `flat @ mat[:,:2].T`.
    homog = torch.addmm(mat[:,2], flat, mat[:,:2].t())
    # Divide every row by its third component, then drop that component.
    homog.mul_(1/homog[:,2].unsqueeze(1))
    coords.flow = homog[:,:2].view(orig_shape)
    return coords
# The four reference corners used by the perspective transforms.
_orig_pts = [[-1,-1], [-1,1], [1,-1], [1,1]]
def _do_perspective_warp(c:FlowField, targ_pts:Points, invert=False):
    "Apply warp to `targ_pts` from `_orig_pts` to `c` `FlowField`."
    # `invert` swaps the roles of the reference and target corners.
    if invert: coeffs = _find_coeffs(targ_pts, _orig_pts)
    else: coeffs = _find_coeffs(_orig_pts, targ_pts)
    return _apply_perspective(c, coeffs)
def _perspective_warp(c, magnitude:partial(uniform,size=8)=0, invert=False):
    "Apply warp of `magnitude` to `c`."
    # One (dx, dy) offset per reference corner.
    offsets = magnitude.view(4,2)
    targ_pts = []
    for corner, delta in zip(_orig_pts, offsets):
        targ_pts.append([coord+d for coord,d in zip(corner, delta)])
    return _do_perspective_warp(c, targ_pts, invert)
perspective_warp = TfmCoord(_perspective_warp)
def _symmetric_warp(c, magnitude:partial(uniform,size=4)=0, invert=False):
    "Apply symmetric warp of `magnitude` to `c`."
    m = listify(magnitude, 4)
    # Each of the four magnitudes is shared by two corners.
    corner0 = [-1-m[3], -1-m[1]]
    corner1 = [-1-m[2],  1+m[1]]
    corner2 = [ 1+m[3], -1-m[0]]
    corner3 = [ 1+m[2],  1+m[0]]
    return _do_perspective_warp(c, [corner0, corner1, corner2, corner3], invert)
symmetric_warp = TfmCoord(_symmetric_warp)
def _tilt(c, direction:uniform_int, magnitude:uniform=0, invert=False):
    "Tilt `c` field with random `direction` and `magnitude`."
    # Each direction (0-3) pushes one coordinate of two corners outwards by `magnitude`.
    lo, hi = -1-magnitude, 1+magnitude
    if direction == 0:   targ_pts = [[-1,-1], [-1,1], [1,lo], [1,hi]]
    elif direction == 1: targ_pts = [[-1,lo], [-1,hi], [1,-1], [1,1]]
    elif direction == 2: targ_pts = [[-1,-1], [lo,1], [1,-1], [hi,1]]
    elif direction == 3: targ_pts = [[lo,-1], [-1,1], [hi,-1], [1,1]]
    if invert: coeffs = _find_coeffs(targ_pts, _orig_pts)
    else: coeffs = _find_coeffs(_orig_pts, targ_pts)
    return _apply_perspective(c, coeffs)
tilt = TfmCoord(_tilt)
def _skew(c, direction:uniform_int, magnitude:uniform=0, invert=False):
    """Skew `c` field with random `direction` and `magnitude`.

    `direction` (0 to 7) selects a single coordinate of a single corner of `_orig_pts`:
    corner `direction // 2`, first coordinate when `direction` is even and second when odd.
    That coordinate is pushed away from 0 by `magnitude`; the other corners are unchanged.
    `invert` swaps the roles of the original and target corners.
    Raises `ValueError` for any other `direction`.
    """
    if direction == 0: targ_pts = [[-1-magnitude,-1], [-1,1], [1,-1], [1,1]]
    elif direction == 1: targ_pts = [[-1,-1-magnitude], [-1,1], [1,-1], [1,1]]
    elif direction == 2: targ_pts = [[-1,-1], [-1-magnitude,1], [1,-1], [1,1]]
    elif direction == 3: targ_pts = [[-1,-1], [-1,1+magnitude], [1,-1], [1,1]]
    elif direction == 4: targ_pts = [[-1,-1], [-1,1], [1+magnitude,-1], [1,1]]
    elif direction == 5: targ_pts = [[-1,-1], [-1,1], [1,-1-magnitude], [1,1]]
    elif direction == 6: targ_pts = [[-1,-1], [-1,1], [1,-1], [1+magnitude,1]]
    elif direction == 7: targ_pts = [[-1,-1], [-1,1], [1,-1], [1,1+magnitude]]
    # Fail with an explicit message instead of an unbound `targ_pts` further down.
    else: raise ValueError(f"`direction` must be an integer from 0 to 7, got {direction}")
    # Same original/target (or swapped, when `invert`) call as the other warps.
    return _do_perspective_warp(c, targ_pts, invert)
skew = TfmCoord(_skew)
def get_transforms(do_flip:bool=True, flip_vert:bool=False, max_rotate:float=10., max_zoom:float=1.1,
                   max_lighting:float=0.2, max_warp:float=0.2, p_affine:float=0.75,
                   p_lighting:float=0.75, xtra_tfms:Optional[Collection[Transform]]=None)->Collection[Transform]:
    "Build the (train, valid) transform lists: random crop plus optional flip, warp, rotate, zoom and lighting."
    train_tfms = [rand_crop()]
    if do_flip:
        flip_tfm = dihedral_affine() if flip_vert else flip_lr(p=0.5)
        train_tfms.append(flip_tfm)
    if max_warp:
        train_tfms.append(symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine))
    if max_rotate:
        train_tfms.append(rotate(degrees=(-max_rotate,max_rotate), p=p_affine))
    if max_zoom>1:
        train_tfms.append(rand_zoom(scale=(1.,max_zoom), p=p_affine))
    if max_lighting:
        brightness_range = (0.5*(1-max_lighting), 0.5*(1+max_lighting))
        train_tfms.append(brightness(change=brightness_range, p=p_lighting))
        contrast_range = (1-max_lighting, 1/(1-max_lighting))
        train_tfms.append(contrast(scale=contrast_range, p=p_lighting))
    # First element is for training (with any extra transforms appended), second for validation.
    train_tfms = train_tfms + listify(xtra_tfms)
    valid_tfms = [crop_pad()]
    return (train_tfms, valid_tfms)
def _compute_zs_mat(sz:TensorImageSize, scale:float, squish:float,
                    invert:bool, row_pct:float, col_pct:float)->AffineMatrix:
    "Return the zoom matrix for the first usable (scale, squish, invert) candidate, else a center-crop-like one."
    orig_ratio = math.sqrt(sz[1]/sz[0])
    for sc, sq, inv in zip(scale, squish, invert):
        zoom_f, squish_f = 1/math.sqrt(sc), math.sqrt(sq)
        # Skip candidates that are not completely inside the picture.
        if not (zoom_f * squish_f <= 1 and zoom_f / squish_f <= 1): continue
        if inv: w,h = zoom_f/squish_f, zoom_f*squish_f
        else:   w,h = zoom_f*squish_f, zoom_f/squish_f
        col_c = (1-w) * (2*col_pct - 1)
        row_c = (1-h) * (2*row_pct - 1)
        return _get_zoom_mat(w, h, col_c, row_c)
    # Fallback, hack to emulate a center crop without cropping anything yet.
    squared_ratio = orig_ratio**2
    if orig_ratio > 1: return _get_zoom_mat(1/squared_ratio, 1, 0, 0.)
    return _get_zoom_mat(1, squared_ratio, 0, 0.)
def _zoom_squish(c, scale:uniform=1.0, squish:uniform=1.0, invert:rand_bool=False,
                 row_pct:uniform=0.5, col_pct:uniform=0.5):
    # `scale`, `squish` and `invert` are intended to hold several values (10 or whatever) so that
    # the transform can try a few zoom/squish combinations before falling back to a center crop
    # (like torchvision.RandomResizedCrop).
    zs_mat = _compute_zs_mat(c.size, scale, squish, invert, row_pct, col_pct)
    zs_tensor = FloatTensor(zs_mat)
    return _affine_mult(c, zs_tensor)
zoom_squish = TfmCoord(_zoom_squish)
def rand_resize_crop(size:int, max_scale:float=2., ratios:Tuple[float,float]=(0.75,1.33)):
    "Return a `zoom_squish` (scale up to `max_scale`, squish within `ratios`) followed by a `crop` to `size`."
    # The trailing 8 in each tuple asks for 8 candidate values per parameter.
    zs_tfm = zoom_squish(scale=(1.,max_scale,8), squish=(*ratios,8), invert=(0.5,8),
                         row_pct=(0.,1.), col_pct=(0.,1.))
    crop_tfm = crop(size=size)
    return [zs_tfm, crop_tfm]
================================================
FILE: fastai/vision/tta.py
================================================
"Brings TTA (Test Time Functionality) to the `Learner` class. Use `learner.TTA()` instead"
from ..torch_core import *
from ..basic_train import *
from ..basic_train import _loss_func2activ
from ..basic_data import DatasetType
from .transform import *
__all__ = []
def _tta_only(learn:Learner, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, scale:float=1.35) -> Iterator[List[Tensor]]:
    "Yield the predictions on `ds_type` for each of 8 augmented versions of the inputs (for TTA)."
    dl = learn.dl(ds_type)
    ds = dl.dataset
    saved_tfms = ds.tfms
    activ = ifnone(activ, _loss_func2activ(learn.loss_func))
    # Keep the training transforms except the ones that are re-applied deterministically below.
    augm_tfm = [t for t in learn.data.train_ds.tfms
                if t.tfm not in (crop_pad, flip_lr, dihedral, zoom)]
    try:
        pbar = master_bar(range(8))
        for i in pbar:
            # The three low bits of `i` pick the crop row, the crop column and whether to flip.
            pos = {'row_pct': 1 if i&1 else 0, 'col_pct': 1 if i&2 else 0, 'is_random':False}
            tfm = augm_tfm + [zoom(scale=scale, **pos), crop_pad(**pos)]
            if i&4: tfm.append(flip_lr(p=1.))
            ds.tfms = tfm
            yield get_preds(learn.model, dl, pbar=pbar, activ=activ)[0]
    finally:
        # Always put the dataset's own transforms back, even if iteration stops early.
        ds.tfms = saved_tfms
Learner.tta_only = _tta_only
def _TTA(learn:Learner, beta:float=0.4, scale:float=1.35, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False) -> Tensors:
    "Predict on `ds_type` by blending plain predictions (weight `beta`) with the mean of the TTA predictions."
    preds,y = learn.get_preds(ds_type, activ=activ)
    tta_preds = list(learn.tta_only(ds_type=ds_type, activ=activ, scale=scale))
    avg_preds = torch.stack(tta_preds).mean(0)
    # With no `beta`, hand back both sets of predictions unblended.
    if beta is None: return preds,avg_preds,y
    final_preds = preds*beta + avg_preds*(1-beta)
    if not with_loss: return final_preds, y
    with NoneReduceOnCPU(learn.loss_func) as lf: loss = lf(final_preds, y)
    return final_preds, y, loss
Learner.TTA = _TTA
================================================
FILE: fastai/widgets/__init__.py
================================================
from .class_confusion import *
from .image_cleaner import *
from .image_downloader import *
================================================
FILE: fastai/widgets/class_confusion.py
================================================
import math
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from itertools import permutations
from ..tabular import TabularDataBunch
from ..train import ClassificationInterpretation
import ipywidgets as widgets
class ClassConfusion():
    "Plot the most confused datapoints and statistics for the models misses."
    def __init__(self, interp:ClassificationInterpretation, classlist:list,
                 is_ordered:bool=False, cut_off:int=100, varlist:list=None,
                 figsize:tuple=(8,8)):
        """Store the options, look up the normalization stats for tabular data, then build the tabs.

        `classlist` holds the classes to compare: every ordered pair of them is used, unless
        `is_ordered` is set, in which case `classlist` is used directly as the list of combinations.
        `cut_off` is the threshold above which a categorical bar chart is skipped, `varlist`
        optionally replaces the tabular column names used as tabs and `figsize` is passed to
        `plt.subplots`.
        """
        self.interp = interp
        data = interp.learn.data
        self._is_tab = isinstance(data, TabularDataBunch)
        if self._is_tab and data.train_ds.x.cont_names != []:
            # Remember the stats of the `Normalize` proc so continuous values can be de-normalized.
            for k, proc in enumerate(data.procs):
                if "Normalize" in str(proc):
                    fitted = data.train_ds.x.processor[0].procs[k]
                    self.means = fitted.means
                    self.stds = fitted.stds
        self.is_ordered = is_ordered
        self.cut_off = cut_off
        self.figsize = figsize
        self.varlist = varlist
        self.classl = classlist
        self._show_losses(classlist)

    def _show_losses(self, classl:list, **kwargs):
        "Checks if the model is for Tabular or Images and gathers top losses"
        _, self.tl_idx = self.interp.top_losses(len(self.interp.losses))
        if self._is_tab: self._tab_losses()
        else: self._create_tabs()

    def _create_tabs(self):
        "Creates a tab for each variable (tabular) or each confused class pair (images), then fills them."
        self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2))
        if self._is_tab:
            self._boxes = len(self.df_list)
            self._cols = math.ceil(math.sqrt(self._boxes))
            self._rows = math.ceil(self._boxes/self._cols)
            # Every column but the last one (the marker column added in `_tab_losses`) gets a tab.
            self.tbnames = list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist
        else:
            vals = self.interp.most_confused()
            self._ranges = []
            self.tbnames = []
            self._boxes = int(input('Please enter a value for `k`, or the top images you will see: '))
            # Keep only the confused pairs that were asked for, with their count in `_ranges`.
            for x in vals:
                for pair in self.lis:
                    if x[0:2] == pair:
                        self._ranges.append(x[2])
                        self.tbnames.append(str(x[0] + ' | ' + x[1]))
        items = [widgets.Output() for _ in self.tbnames]
        self.tabs = widgets.Tab()
        self.tabs.children = items
        for i, name in enumerate(self.tbnames):
            self.tabs.set_title(i, name)
        self._populate_tabs()

    def _populate_tabs(self):
        "Adds relevant graphs to each tab"
        # Local import: `display` is not imported in this module's header.
        from IPython.display import display
        with tqdm(total=len(self.tbnames)) as pbar:
            for i, tab in enumerate(self.tbnames):
                # Anything plotted inside the `with` is captured by that tab's Output widget.
                with self.tabs.children[i]:
                    if self._is_tab: self._plot_tab(tab)
                    else: self._plot_imgs(tab, i)
                pbar.update(1)
        display(self.tabs)

    def _plot_tab(self, tab:str):
        "Plot the distribution of variable `tab` for every dataframe in `df_list`."
        stacked = self._boxes is not None
        # `squeeze=False` keeps `ax` 2-D so it can be indexed even when there is a single subplot.
        if stacked:
            fig, ax = plt.subplots(self._boxes, figsize=self.figsize, squeeze=False)
        else:
            fig, ax = plt.subplots(self._rows, self._cols, figsize=self.figsize, squeeze=False)
        fig.subplots_adjust(hspace=.5)
        for j, x in enumerate(self.df_list):
            title = f'{"".join(x.columns[-1])} {tab} distribution'
            # One subplot per dataframe: a single column when stacked, row-major in the grid otherwise.
            target = ax[j, 0] if stacked else ax[j // self._cols, j % self._cols]
            if tab in self.cat_names:
                vals = pd.value_counts(x[tab].values)
                # NOTE(review): `vals` holds counts, so `nunique()` is the number of distinct counts
                # rather than of distinct categories -- confirm this is the intended test.
                if not stacked:
                    vals.plot(kind='barh', title=title, ax=target, width=.75)
                elif vals.nunique() < 10:
                    vals.plot(kind='bar', title=title, ax=target, rot=0, width=.75)
                elif vals.nunique() > self.cut_off:
                    print(f'Number of values is above {self.cut_off}')
                else:
                    vals.plot(kind='barh', title=title, ax=target, width=.75)
            else:
                vals = x[tab]
                axs = vals.plot(kind='hist', ax=target, title=title, y='Frequency')
                axs.set_ylabel('Frequency')
                if len(set(vals)) > 1:
                    vals.plot(kind='kde', ax=axs, title=title, secondary_y=True)
                else:
                    print('Less than two unique values, cannot graph the KDE')
        plt.show(fig)
        plt.tight_layout()

    def _plot_imgs(self, tab:str, i:int ,**kwargs):
        "Plots the most confused images for the class pair named by `tab` (the `i`-th tab)."
        # Local import: `re` is not imported in this module's header.
        import re
        classes_gnd = self.interp.data.classes
        n_confused, n_boxes = self._ranges[i], self._boxes
        # NOTE(review): the grid computed in this first branch is overwritten by the if/else
        # right below -- confirm which sizing was intended.
        if n_confused < n_boxes:
            cols = math.ceil(math.sqrt(n_confused))
            rows = math.ceil(n_confused/cols)
        if n_confused < 4 or n_boxes < 4:
            cols = 2
            rows = 2
        else:
            cols = math.ceil(math.sqrt(n_boxes))
            rows = math.ceil(n_boxes/cols)
        fig, ax = plt.subplots(rows, cols, figsize=self.figsize)
        for axi in ax.ravel(): axi.set_axis_off()
        # Tab names are built as '<actual> | <predicted>' in `_create_tabs`.
        parts = tab.split(' ')
        actual, predicted = parts[0], parts[2]
        dataset = self.interp.data.dl(self.interp.ds_type).dataset
        shown = 0
        for idx in self.tl_idx:
            if n_boxes < shown+1 or shown > n_confused:
                break
            _, cl = dataset[idx]
            if str(cl) != actual or str(classes_gnd[self.interp.pred_class[idx]]) != predicted:
                continue
            row, col = divmod(shown, cols)
            img, _ = self.interp.data.valid_ds[idx]
            fn = str(self.interp.data.valid_ds.x.items[idx])
            # Use the '<name>_<digits>...' tail of the path as title; fall back to the whole path
            # when the file name does not follow that pattern (instead of failing on None).
            found = re.search(r'([^/*]+)_\d+.*$', fn)
            img.show(ax=ax[row, col])
            ax[row,col].set_title(found.group(0) if found else fn)
            shown += 1
        plt.show(fig)
        plt.tight_layout()

    @staticmethod
    def _row_to_str(da) -> str:
        "Serialize one tabular row `da` as ';'-terminated fields: categorical values, then continuous ones."
        res = ''
        for c, n in zip(da.cats, da.names[:len(da.cats)]):
            string = f'{da.classes[n][c]}'
            # 'True'/'False' are kept whole; any other category string loses its first character.
            if string != 'True' and string != 'False': string = string[1:]
            res += string + ';'
        for c, n in zip(da.conts, da.names[len(da.cats):]):
            res += f'{c:.4f};'
        return res

    def _rows_to_frame(self, rows:list, columns):
        "Turn serialized `rows` into a DataFrame with `columns`, de-normalizing the continuous variables."
        f = pd.DataFrame([r.split(';')[:-1] for r in rows], columns=columns)
        for var in self.interp.data.cont_names:
            f[var] = f[var].apply(lambda v: float(v) * self.stds[var] + self.means[var])
        return f

    def _tab_losses(self, **kwargs):
        "Gathers one dataframe with every row, then one per class combination, and builds the tabs."
        classes = self.interp.data.classes
        cat_names = self.interp.data.x.cat_names
        comb = self.classl if self.is_ordered else list(permutations(self.classl,2))
        dataset = self.interp.data.dl(self.interp.ds_type).dataset
        # Read and serialize every row once, in top-loss order, instead of once per combination.
        rows = []
        for idx in self.tl_idx:
            da, cl = dataset[idx]
            rows.append((classes[self.interp.pred_class[idx]], classes[int(cl)], self._row_to_str(da)))
        columns = da.names
        everything = self._rows_to_frame([r for _, _, r in rows], columns)
        everything['Original'] = 'Original'
        self.df_list = [everything]
        for pair in comb:
            # A row belongs to `pair` when it was predicted as pair[0] while labelled pair[1].
            matching = [r for pred, actual, r in rows if pred == pair[0] and actual == pair[1]]
            f = self._rows_to_frame(matching, columns)
            f[str(pair)] = str(pair)
            self.df_list.append(f)
        self.cat_names = cat_names
        self._create_tabs()
================================================
FILE: fastai/widgets/image_cleaner.py
================================================
from ..torch_core import *
from ..basic_train import *
from ..basic_data import *
from ..vision.data import *
from ..vision.transform import *
from ..vision.image import *
from ..callbacks.hooks import *
from ..layers import *
from ipywidgets import widgets, Layout
from IPython.display import clear_output, display
__all__ = ['DatasetFormatter', 'ImageCleaner']
class DatasetFormatter():
    "Returns a dataset with the appropriate format and file indices to be displayed."
    @classmethod
    def from_toplosses(cls, learn, n_imgs=None, **kwargs):
        "Gets indices with top losses; returns `(dataset, idxs)` from `get_toplosses_idxs`."
        return cls.get_toplosses_idxs(learn, n_imgs, **kwargs)

    @classmethod
    def get_toplosses_idxs(cls, learn, n_imgs, **kwargs):
        "Sorts the `fix_dl` dataset by top losses and returns the padded dataset and the sorted indices."
        dl = learn.data.fix_dl
        # A falsy `n_imgs` (None or 0) means "all the images".
        if not n_imgs: n_imgs = len(dl.dataset)
        _,_,top_losses = learn.get_preds(ds_type=DatasetType.Fix, with_loss=True)
        idxs = torch.topk(top_losses, n_imgs)[1]
        return cls.padded_ds(dl.dataset, **kwargs), idxs

    @staticmethod
    def padded_ds(ll_input, size=(250, 300), resize_method=ResizeMethod.CROP, padding_mode='zeros', **kwargs):
        "For a LabelList `ll_input`, resize each image to `size` using `resize_method` and `padding_mode`."
        # Extra `kwargs` are accepted and ignored so callers can forward one shared kwargs dict.
        return ll_input.transform(tfms=crop_pad(), size=size, resize_method=resize_method, padding_mode=padding_mode)

    @classmethod
    def from_similars(cls, learn, layer_ls:list=[0, 7, 2], **kwargs):
        "Gets the indices for the most similar images; returns `(dataset, idxs)` from `get_similars_idxs`."
        return cls.get_similars_idxs(learn, layer_ls, **kwargs)

    @classmethod
    def get_similars_idxs(cls, learn, layer_ls, **kwargs):
        "Gets the indices for the most similar images in the `fix_dl` dataset."
        # `layer_ls` is a 3-level index into `learn.model` naming the layer whose output is compared.
        hook = hook_output(learn.model[layer_ls[0]][layer_ls[1]][layer_ls[2]])
        dl = learn.data.fix_dl
        # The hook is only needed while collecting activations: remove it afterwards so repeated
        # calls do not leave extra hooks registered on the model.
        try: ds_actns = cls.get_actns(learn, hook=hook, dl=dl, **kwargs)
        finally: hook.remove()
        similarities = cls.comb_similarity(ds_actns, ds_actns, **kwargs)
        idxs = cls.sort_idxs(similarities)
        return cls.padded_ds(dl, **kwargs), idxs

    @staticmethod
    def get_actns(learn, hook:Hook, dl:DataLoader, pool=AdaptiveConcatPool2d, pool_dim:int=4, **kwargs):
        "Gets activations at the layer specified by `hook`, applies `pool` of dim `pool_dim` and concatenates"
        print('Getting activations...')
        actns = []
        learn.model.eval()
        with torch.no_grad():
            for (xb,yb) in progress_bar(dl):
                learn.model(xb)
                # The forward pass above made `hook` store this batch's activations.
                actns.append((hook.stored).cpu())
        actns = torch.cat(actns)
        if pool: actns = pool(pool_dim)(actns)
        # One flat row of features per item of `dl.x`.
        return actns.view(len(dl.x),-1)

    @staticmethod
    def comb_similarity(t1: torch.Tensor, t2: torch.Tensor, **kwargs):
        "Computes the similarity function between each embedding of `t1` and `t2` matrices."
        # https://github.com/pytorch/pytorch/issues/11202
        print('Computing similarities...')
        w1 = t1.norm(p=2, dim=1, keepdim=True)
        w2 = w1 if t2 is t1 else t2.norm(p=2, dim=1, keepdim=True)
        # Dot products divided by the product of the norms (clamped to avoid dividing by zero).
        t = torch.mm(t1, t2.t()) / (w1 * w2.t()).clamp(min=1e-8)
        # Keep the strictly lower triangle only: each pair once, no self-similarity.
        return torch.tril(t, diagonal=-1)

    @staticmethod
    def largest_indices(arr, n):
        "Returns the indices of the `n` largest values of array `arr`, largest first, as a tuple of index arrays."
        #https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        flat = arr.flatten()
        # Partial sort to isolate the `n` largest, then order just those.
        indices = np.argpartition(flat, -n)[-n:]
        indices = indices[np.argsort(-flat[indices])]
        return np.unravel_index(indices, arr.shape)

    @classmethod
    def sort_idxs(cls, similarities):
        "Sorts `similarities` and return the indexes in pairs ordered by highest similarity."
        idxs = cls.largest_indices(similarities, len(similarities))
        # Flatten the (row, col) pairs into one list: [r0, c0, r1, c1, ...].
        return [e for pair in zip(idxs[0], idxs[1]) for e in pair]
class ImageCleaner():
    "Displays images for relabeling or deletion and saves changes in `path` as 'cleaned.csv'."
    def __init__(self, dataset, fns_idxs, path, batch_size:int=5, duplicates=False):
        """Prepare the images of `dataset` listed in `fns_idxs` and render the first batch.

        `path` is the directory where 'cleaned.csv' is written. With `duplicates`, images are
        shown two at a time (`batch_size` is ignored) and `fns_idxs` is read as consecutive pairs.
        """
        self._all_images,self._batch = [],[]
        self._path = Path(path)
        self._batch_size = 2 if duplicates else batch_size
        self._duplicates = duplicates
        self._labels = dataset.classes
        self._all_images = self.create_image_list(dataset, fns_idxs)
        # Current file -> label mapping; this is what `write_csv` saves.
        self._csv_dict = {dataset.x.items[i]: dataset.y[i] for i in range(len(dataset))}
        self._deleted_fns = []
        self._skipped = 0
        self.render()

    @classmethod
    def make_img_widget(cls, img, layout=Layout(), format='jpg'):
        "Returns an image widget for the encoded image bytes `img`."
        return widgets.Image(value=img, format=format, layout=layout)

    @classmethod
    def make_button_widget(cls, label, file_path=None, handler=None, style=None, layout=Layout(width='auto')):
        "Return a Button widget with specified `handler`."
        btn = widgets.Button(description=label, layout=layout)
        if handler is not None: btn.on_click(handler)
        if style is not None: btn.button_style = style
        # Extra attributes read back by `on_delete` and `next_batch`.
        btn.file_path = file_path
        btn.flagged_for_delete = False
        return btn

    @classmethod
    def make_dropdown_widget(cls, description='Description', options=['Label 1', 'Label 2'], value='Label 1',
                             file_path=None, layout=Layout(), handler=None):
        "Return a Dropdown widget with specified `handler`."
        dd = widgets.Dropdown(description=description, options=options, value=value, layout=layout)
        if file_path is not None: dd.file_path = file_path
        if handler is not None: dd.observe(handler, names=['value'])
        return dd

    @classmethod
    def make_horizontal_box(cls, children, layout=Layout()):
        "Make a horizontal box with `children` and `layout`."
        return widgets.HBox(children, layout=layout)

    @classmethod
    def make_vertical_box(cls, children, layout=Layout(), duplicates=False):
        "Make a vertical box with `children` and `layout`; with `duplicates` the middle child is left out."
        if duplicates: children = [children[0], children[2]]
        return widgets.VBox(children, layout=layout)

    def create_image_list(self, dataset, fns_idxs):
        "Create a list of images, filenames and labels but first removing files that are not supposed to be displayed."
        items = dataset.x.items
        def _entry(i): return (dataset.x[i]._repr_jpeg_(), items[i], self._labels[dataset.y[i].data])
        def _on_disk(i): return Path(items[i]).is_file()
        if self._duplicates:
            # Keep a pair only when it is complete and both of its files still exist.
            pairs = [chunk for chunk in chunks(fns_idxs, 2) if len(chunk) == 2 and all(_on_disk(i) for i in chunk)]
            return [_entry(i) for chunk in pairs for i in chunk]
        return [_entry(i) for i in fns_idxs if _on_disk(i)]

    def relabel(self, change):
        "Dropdown handler: record the newly selected label for the file attached to the dropdown."
        # Use the file entry exactly as stored on the widget: it is the same object used as key
        # when `_csv_dict` was built, so the existing row is updated rather than a new one added.
        self._csv_dict[change.owner.file_path] = change.new

    def next_batch(self, _):
        "Handler for 'Next Batch' button click. Delete all flagged images and renders next batch."
        for _img_widget, delete_btn, _fp in self._batch:
            if delete_btn.flagged_for_delete:
                fp = delete_btn.file_path
                self.delete_image(fp)
                self._deleted_fns.append(fp)
        self._all_images = self._all_images[self._batch_size:]
        self.empty_batch()
        self.render()

    def on_delete(self, btn):
        "Flag this image as delete or keep."
        btn.button_style = "" if btn.flagged_for_delete else "danger"
        btn.flagged_for_delete = not btn.flagged_for_delete

    def empty_batch(self):
        "Forget the widgets of the current batch."
        self._batch[:] = []

    def delete_image(self, file_path):
        "Drop `file_path` from the entries to be written to the csv (no error if already dropped)."
        self._csv_dict.pop(file_path, None)

    def empty(self):
        "Return True when there is no image left to show."
        return len(self._all_images) == 0

    def get_widgets(self, duplicates):
        "Create and format widget set."
        boxes = []
        for (img,fp,human_readable_label) in self._all_images[:self._batch_size]:
            img_widget = self.make_img_widget(img, layout=Layout(height='250px', width='300px'))
            dropdown = self.make_dropdown_widget(description='', options=self._labels, value=human_readable_label,
                                                 file_path=fp, handler=self.relabel, layout=Layout(width='auto'))
            delete_btn = self.make_button_widget('Delete', file_path=fp, handler=self.on_delete)
            boxes.append(self.make_vertical_box([img_widget, dropdown, delete_btn],
                                                layout=Layout(width='auto', height='300px',
                                                              overflow_x="hidden"), duplicates=duplicates))
            self._batch.append((img_widget, delete_btn, fp))
        return boxes

    def batch_contains_deleted(self):
        "Check if current batch contains already deleted images."
        if not self._duplicates: return False
        # In duplicates mode the batch is the next pair of images.
        return any(fp in self._deleted_fns for _, fp, _ in self._all_images[:self._batch_size])

    def write_csv(self):
        "Write the current file/label pairs to 'cleaned.csv' in `path` and return the csv path."
        csv_path = self._path/'cleaned.csv'
        # `newline=''` is what the csv module expects; without it some platforms get blank rows.
        with open(csv_path, 'w', newline='') as f:
            csv_writer = csv.writer(f)
            csv_writer.writerow(['name','label'])
            for fn, label in self._csv_dict.items():
                # File names are stored relative to `path`.
                csv_writer.writerow([os.path.relpath(fn, self._path), label])
        return csv_path

    def render(self):
        "Re-render Jupyter cell for batch of images."
        clear_output()
        self.write_csv()
        if self.empty() and self._skipped>0:
            return display(f'No images to show :). {self._skipped} pairs were '
                           f'skipped since at least one of the images was deleted by the user.')
        elif self.empty():
            return display('No images to show :)')
        if self.batch_contains_deleted():
            # Count the skip before moving on: `next_batch` re-renders and may show the final count.
            self._skipped += 1
            self.next_batch(None)
        else:
            display(self.make_horizontal_box(self.get_widgets(self._duplicates)))
            display(self.make_button_widget('Next Batch', handler=self.next_batch, style="primary"))
================================================
FILE: fastai/widgets/image_downloader.py
================================================
from ..core import *
from ..vision.data import *
from ipywidgets import widgets, Layout, Output, HBox, VBox, Text, BoundedIntText, Button, Dropdown, Box
from IPython.display import clear_output, display
from urllib.parse import quote
from bs4 import BeautifulSoup
import time
__all__ = ['ImageDownloader', 'download_google_images']
_img_sizes = {'>400*300':'isz:lt,islt:qsvga','>640*480':'isz:lt,islt:vga','>800*600':'isz:lt,islt:svga',
'>1024*768':'visz:lt,islt:xga', '>2MP':'isz:lt,islt:2mp','>4MP':'isz:lt,islt:4mp','>6MP':'isz:lt,islt:6mp',
'>8MP':'isz:lt,islt:8mp', '>10MP':'isz:lt,islt:10mp','>12MP':'isz:lt,islt:12mp','>15MP':'isz:lt,islt:15mp',
'>20MP':'isz:lt,islt:20mp','>40MP':'isz:lt,islt:40mp','>70MP':'isz:lt,islt:70mp'}
class ImageDownloader():
    """
    Displays a widget that allows searching and downloading images from google images search
    in a Jupyter Notebook or Lab.
    """
    def __init__(self, path:Union[Path,str]='data'):
        "Setup path to save images to, init the UI, and render the widgets."
        self._path = Path(path)
        self._ui = self._init_ui()
        self.render()

    def _init_ui(self) -> VBox:
        "Initialize the widget UI and return the UI."
        self._search_input = Text(placeholder="What images to search for?")
        self._count_input = BoundedIntText(placeholder="How many pics?", value=10, min=1, max=5000, step=1,
                                           layout=Layout(width='60px'))
        self._size_input = Dropdown(options= _img_sizes.keys(), value='>400*300', layout=Layout(width='120px'))
        self._download_button = Button(description="Search & Download", icon="download", layout=Layout(width='200px'))
        self._download_button.on_click(self.on_download_button_click)
        self._output = Output()
        self._controls_pane = HBox([self._search_input, self._count_input, self._size_input, self._download_button],
                                   layout=Layout(width='auto', height='40px'))
        self._heading = ""
        # NOTE(review): this literal was split across two lines in the source (an unterminated
        # string); it is kept here as plain text -- confirm whether surrounding markup was intended.
        self._download_complete_heading = "Download complete. Here are a few images"
        self._preview_header = widgets.HTML(self._heading, layout=Layout(height='60px'))
        self._img_pane = Box(layout=Layout(display='inline'))
        return VBox([self._controls_pane, self._preview_header, self._img_pane])

    def render(self) -> None:
        "Clear the cell output and display the UI again."
        clear_output()
        display(self._ui)

    def clear_imgs(self) -> None:
        "Clear the widget's images preview pane."
        self._preview_header.value = self._heading
        self._img_pane.children = tuple()

    def validate_search_input(self) -> bool:
        "Return True if a search term was typed; otherwise outline the search box in red and return False."
        search = self._search_input
        has_term = search.value != str()
        search.layout = Layout() if has_term else Layout(border="solid 2px red", height='auto')
        return has_term

    def on_download_button_click(self, btn) -> None:
        "Download button click handler: validate search term and download images."
        term = self._search_input.value
        limit = int(self._count_input.value)
        size = self._size_input.value
        if not self.validate_search_input(): return
        self.clear_imgs()
        downloaded_images = download_google_images(self._path, term, n_images=limit, size=size)
        # Preview at most 12 of the downloaded images.
        self.display_images_widgets(downloaded_images[:min(limit, 12)])
        self._preview_header.value = self._download_complete_heading
        self.render()

    def display_images_widgets(self, fnames:list) -> None:
        "Display a few preview images in the notebook"
        # `read_bytes` opens and closes each file, so no file handle is left open.
        imgs = [widgets.Image(value=Path(f).read_bytes(), width='200px') for f in fnames]
        self._img_pane.children = tuple(imgs)
def download_google_images(path:PathOrStr, search_term:str, size:str='>400*300', n_images:int=10, format:str='jpg',
                            max_workers:int=defaults.cpus, timeout:int=4) -> FilePathList:
    """
    Search Google for `n_images` images matching `search_term` and `size`, download them
    into `path`/`search_term` with `max_workers` workers, verify them and return the image files.
    Raises `RuntimeError` when nothing could be downloaded.
    """
    label_path = Path(path)/search_term
    search_url = _search_url(search_term, size=size, format=format)
    # Up to 100 images are fetched with a plain request; more than that goes through the webdriver.
    fetch = _fetch_img_tuples if n_images <= 100 else _fetch_img_tuples_webdriver
    img_tuples = fetch(search_url, format=format, n_images=n_images)
    downloaded_images = _download_images(label_path, img_tuples, max_workers=max_workers, timeout=timeout)
    if len(downloaded_images) == 0:
        raise RuntimeError(f"Couldn't download any images.")
    verify_images(label_path, max_workers=max_workers)
    # List the folder again after verification.
    return get_image_files(label_path)
def _url_params(size:str='>400*300', format:str='jpg') -> str:
    """
    Build Google Images Search Url params and return them as a string.
    `size` must be a key of `_img_sizes` and `format` one of jpg, gif, png, bmp, svg, webp, ico;
    a `RuntimeError` is raised otherwise.
    """
    # Every file-type filter uses the same 'ift:<format>' form.
    _fmts = {'jpg':'ift:jpg','gif':'ift:gif','png':'ift:png','bmp':'ift:bmp', 'svg':'ift:svg','webp':'ift:webp','ico':'ift:ico'}
    if size not in _img_sizes:
        raise RuntimeError(f"""Unexpected size argument value: {size}.
See `widgets.image_downloader._img_sizes` for supported sizes.""")
    if format not in _fmts:
        raise RuntimeError(f"Unexpected image file format: {format}. Use jpg, gif, png, bmp, svg, webp, or ico.")
    return "&tbs=" + _img_sizes[size] + "," + _fmts[format]
def _search_url(search_term:str, size:str='>400*300', format:str='jpg') -> str:
    "Return a Google Images Search URL for `search_term`, filtered by `size` and `format`."
    query = 'https://www.google.com/search?q=' + quote(search_term)
    fixed_params = '&espv=2&biw=1366&bih=667&site=webhp&source=lnms&tbm=isch'
    filters = _url_params(size, format)
    trailer = '&sa=X&ei=XosDVaCXD8TasATItgE&ved=0CAcQ_AUoAg'
    return ''.join([query, fixed_params, filters, trailer])
def _img_fname(img_url:str) -> str:
"Return image file name including the extension given its url."
return img_url.split('/')[-1]
def _fetch_img_tuples(url:str, format:str='jpg', n_images:int=10) -> list:
    "Request the search page at `url` and return up to `n_images` image tuples (fname, url) of type `format`."
    # The request is sent with a desktop Chrome User-Agent header.
    user_agent = 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36'
    response = requests.get(url, headers={'User-Agent': user_agent})
    return _html_to_img_tuples(response.text, format=format, n_images=n_images)
def _html_to_img_tuples(html:str, format:str='jpg', n_images:int=10) -> list:
    "Parse the google images html and return at most `n_images` tuples `(fname, url)` of type `format`."
    soup = BeautifulSoup(html, 'html.parser')
    def _matching():
        # Each 'rg_meta' div holds a JSON document; 'ou' is read as the image url, 'ity' as its type.
        for tag in soup.find_all('div', {'class': 'rg_meta'}):
            meta = json.loads(tag.text)
            if meta['ity'] == format:
                yield (_img_fname(meta['ou']), meta['ou'])
    # Lazy: parsing stops as soon as `n_images` matches have been collected.
    return list(itertools.islice(_matching(), n_images))
def _fetch_img_tuples_webdriver(url:str, format:str='jpg', n_images:int=150) -> list:
    """
    Parse the Google Images Search for urls and return the image metadata as tuples (fname, url).
    Use this for downloads of >100 images. Requires `selenium`, chrome and `chromedriver`.
    Raises `ImportError` if `selenium` is missing, re-raises any chromedriver start-up error,
    and raises `ValueError` if fewer than `n_images` results are found.
    """
    try:
        from selenium import webdriver
    except ImportError:
        # Keep the hint for notebook users, then stop here: nothing below can work without selenium.
        print("Looks like you're trying to download > 100 images and `selenium`\n"
              "is not installed. Try running `pip install selenium` to fix this.\n"
              "You'll also need chrome and `chromedriver` installed.")
        raise
    options = webdriver.ChromeOptions()
    options.add_argument("--headless")
    try: driver = webdriver.Chrome(chrome_options=options)
    except Exception:
        print("Error initializing chromedriver.\n"
              "Check if it's in your path by running `which chromedriver`")
        raise
    try:
        driver.set_window_size(1440, 900)
        driver.get(url)
        # Scroll to the bottom once per 100 requested images (plus once), pausing 0.5-1s each time.
        for _ in range(n_images // 100 + 1):
            driver.execute_script("window.scrollTo(0, document.body.scrollHeight)")
            time.sleep(0.5 + random.random()/2.0)
        n_available = len(driver.find_elements_by_css_selector("div.rg_meta"))
        if n_available < n_images:
            raise ValueError(f"Requested {n_images} images, but only found {n_available}.")
        html = driver.page_source
    finally:
        # Close the browser window on every path, including the errors above.
        driver.close()
    return _html_to_img_tuples(html, format=format, n_images=n_images)
def _download_images(label_path:PathOrStr, img_tuples:list, max_workers:int=defaults.cpus, timeout:int=4) -> FilePathList:
    """
    Download every image of `img_tuples` into `label_path` and return the image files found there.
    `label_path` is created when it doesn't exist yet.
    The downloads are handed to `parallel` with `max_workers` workers.
    If something doesn't work, try setting up `max_workers=0` to debug.
    """
    os.makedirs(Path(label_path), exist_ok=True)
    download_one = partial(_download_single_image, label_path, timeout=timeout)
    parallel(download_one, img_tuples, max_workers=max_workers)
    return get_image_files(label_path)
def _download_single_image(label_path:Path, img_tuple:tuple, i:int, timeout:int=4) -> None:
    """
    Downloads a single image from Google Search results to `label_path`
    given an `img_tuple` that contains `(fname, url)` of an image to download.
    `i` is just an iteration number `int`.

    The file is saved as the zero-padded iteration number plus the extension
    found in the url (the first `.ext` directly followed by `?` or the end of
    the url, lower-cased); `.jpg` is used when the url has no such extension.
    """
    url = img_tuple[1]
    suffix = re.findall(r'\.\w+?(?=(?:\?|$))', url)
    suffix = suffix[0].lower() if len(suffix)>0 else '.jpg'
    fname = f"{i:08d}{suffix}"
    # `_download_images` forwards its `label_path` unchanged, which may be a
    # plain string; wrap it so the `/` join works for both str and Path.
    download_url(url, Path(label_path)/fname, timeout=timeout)
================================================
FILE: fid/LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: fid/fid_score.py
================================================
#!/usr/bin/env python3
# Code adapted and modified from https://github.com/mseitzer/pytorch-fid. Licensing
# and description duplicated below.
"""Calculates the Frechet Inception Distance (FID) to evaluate GANs
The FID metric calculates the distance between two distributions of images.
Typically, we have summary statistics (mean & covariance matrix) of one
of these distributions, while the 2nd distribution is given by a GAN.
When run as a stand-alone program, it compares the distribution of
images that are stored as PNG/JPEG at a specified location with a
distribution given by summary statistics (in pickle format).
The FID is calculated by assuming that X_1 and X_2 are the activations of
the pool_3 layer of the inception net for generated samples and real world
samples respectively.
See --help to see further details.
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
of Tensorflow
Copyright 2018 Institute of Bioinformatics, JKU Linz
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import pathlib
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import numpy as np
import torch
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
import cv2
import imageio
try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a mock version of it
    def tqdm(x, *args, **kwargs):
        """Fallback for `tqdm`: return the iterable `x` unchanged.

        Extra positional and keyword arguments are accepted and ignored so
        that calls written for the real `tqdm` still work without it.
        """
        return x
from .inception import InceptionV3
# Command-line interface used when this module is run as a script.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

# The two locations to compare (positional, exactly two values).
parser.add_argument(
    'path',
    nargs=2,
    type=str,
    help='Path to the generated images or to .npz statistic files',
)

parser.add_argument('--batch-size', default=50, type=int, help='Batch size to use')

# Only feature sizes known to InceptionV3.BLOCK_INDEX_BY_DIM are accepted.
parser.add_argument(
    '--dims',
    default=2048,
    type=int,
    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
    help='Dimensionality of Inception features to use. By default, uses pool3 features',
)

parser.add_argument(
    '-c',
    '--gpu',
    type=str,
    default='',
    help='GPU to use (leave blank for CPU only)',
)
def load_image_resized(fn, sz):
    """Read the image file `fn` and return it resized to `sz` x `sz` as float32.

    Params:
    -- fn : Path of the image file (converted with `str` before reading).
    -- sz : Target width and height in pixels; bicubic interpolation is used.

    Grayscale images are expanded to three identical channels and a fourth
    (alpha) channel is dropped, so every returned array has three channels.
    """
    img = imageio.imread(str(fn))
    if img.ndim == 2:
        # Single-channel image: replicate it so the caller's transpose to
        # (n_images, 3, height, width) gets the layout it expects.
        img = np.stack((img,) * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] == 4:
        # Drop the alpha channel; copy so the resize sees a contiguous array.
        img = np.ascontiguousarray(img[..., :3])
    return cv2.resize(img, dsize=(sz, sz), interpolation=cv2.INTER_CUBIC).astype(
        np.float32
    )
def get_activations(
    files,
    model,
    batch_size=50,
    dims=2048,
    cuda=False,
    verbose=False,
    eval_size: int = 299,
):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : Batch size of images for the model to process at once.
                     Make sure that the number of samples is a multiple of
                     the batch size, otherwise some samples are ignored. This
                     behavior is retained to match the original FID score
                     implementation.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : If set to True, the number of the batch being processed
                     is reported.
    -- eval_size   : Width and height every image is resized to when loaded.

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.

    Raises:
    -- ValueError if `files` is empty or `batch_size` is smaller than 1.
    """
    model.eval()

    n_files = len(files)
    if n_files == 0:
        # Previously this ended in a ZeroDivisionError after the batch size
        # had been clamped to zero.
        raise ValueError('No image files given; cannot compute activations.')
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1, got %s' % batch_size)

    # Clamp first so the "not a multiple" warning below is only printed when
    # samples really are going to be dropped.
    if batch_size > n_files:
        print(
            (
                'Warning: batch size is bigger than the data size. '
                'Setting batch size to data size'
            )
        )
        batch_size = n_files
    if n_files % batch_size != 0:
        print(
            (
                'Warning: number of images is not a multiple of the '
                'batch size. Some samples are going to be ignored.'
            )
        )

    n_batches = n_files // batch_size
    n_used_imgs = n_batches * batch_size
    pred_arr = np.empty((n_used_imgs, dims))

    # Only forward passes are needed; skip autograd bookkeeping.
    with torch.no_grad():
        for i in tqdm(range(n_batches)):
            if verbose:
                print(
                    '\rPropagating batch %d/%d' % (i + 1, n_batches),
                    end='',
                    flush=True,
                )
            start = i * batch_size
            end = start + batch_size

            images = np.array(
                [load_image_resized(fn, eval_size) for fn in files[start:end]]
            )
            # Reshape to (n_images, 3, height, width)
            images = images.transpose((0, 3, 1, 2))
            images /= 255

            batch = torch.from_numpy(images).type(torch.FloatTensor)
            if cuda:
                batch = batch.cuda()

            pred = model(batch)[0]

            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            if pred.shape[2] != 1 or pred.shape[3] != 1:
                pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

            pred_arr[start:end] = pred.detach().cpu().numpy().reshape(batch_size, -1)

    if verbose:
        print(' done')

    return pred_arr
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    For two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2)
    the distance is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1    : Mean vector of the first distribution.
    -- sigma1 : Covariance matrix of the first distribution.
    -- mu2    : Mean vector of the second distribution.
    -- sigma2 : Covariance matrix of the second distribution.
    -- eps    : Value added to the diagonal of both covariances when the
                matrix square root of their product is not finite.

    Returns:
    --   : The Frechet Distance.

    Raises:
    -- AssertionError if the means or the covariances differ in shape.
    -- ValueError if the matrix square root has a diagonal whose imaginary
       part is not close to zero.
    """
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert (
        mean_a.shape == mean_b.shape
    ), 'Training and test mean vectors have different lengths'
    assert (
        cov_a.shape == cov_b.shape
    ), 'Training and test covariances have different dimensions'

    delta = mean_a - mean_b

    # The product of the covariances may be close to singular.
    sqrt_prod, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print(
            'fid calculation produces singular product; '
            'adding %s to diagonal of cov estimates' % eps
        )
        jitter = np.eye(cov_a.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off can leave a small imaginary part; only tolerate a tiny one.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            largest = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(largest))
        sqrt_prod = sqrt_prod.real

    return (
        delta.dot(delta)
        + np.trace(cov_a)
        + np.trace(cov_b)
        - 2 * np.trace(sqrt_prod)
    )
def calculate_activation_statistics(
    files, model, batch_size=50, dims=2048, cuda=False, verbose=False
):
    """Compute the mean and covariance that the FID is built from.

    All arguments are forwarded to `get_activations`.

    Params:
    -- files       : List of image files paths
    -- model       : Instance of inception model
    -- batch_size  : Number of images fed to the model at once. A reasonable
                     value depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : Forwarded as the `verbose` flag of `get_activations`.

    Returns:
    -- mu    : Mean of the activations over all samples.
    -- sigma : Covariance matrix of the activations (one variable per column).
    """
    activations = get_activations(
        files, model, batch_size=batch_size, dims=dims, cuda=cuda, verbose=verbose
    )
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
    """Return the `(mu, sigma)` statistics for `path`.

    Params:
    -- path        : Either a `.npz` file holding precomputed `mu` and `sigma`
                     arrays, or a directory whose `*.jpg` and `*.png` files are
                     run through `model`. May be a string or a path object.
    -- model       : Instance of inception model (unused for `.npz` input)
    -- batch_size  : Batch size used when activations have to be computed
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU

    Returns:
    -- m, s : Mean vector and covariance matrix.
    """
    # Normalise to str so path objects work with the suffix check below.
    path_str = str(path)
    if path_str.endswith('.npz'):
        # The context manager closes the archive even if a key is missing.
        with np.load(path_str) as stats:
            m, s = stats['mu'][:], stats['sigma'][:]
    else:
        folder = pathlib.Path(path_str)
        # Sorted so the set of samples dropped by batch truncation does not
        # depend on the file system's listing order.
        files = sorted(folder.glob('*.jpg')) + sorted(folder.glob('*.png'))
        m, s = calculate_activation_statistics(files, model, batch_size, dims, cuda)
    return m, s
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
    """Compute the FID between the data found at `paths[0]` and `paths[1]`.

    Every entry of `paths` must exist, otherwise a RuntimeError naming the
    first missing one is raised. An InceptionV3 model returning the block that
    corresponds to `dims` is built (and moved to the GPU when `cuda` is true)
    and used to obtain the statistics of both paths.
    """
    for candidate in paths:
        if os.path.exists(candidate):
            continue
        raise RuntimeError('Invalid path: %s' % candidate)

    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])
    if cuda:
        model.cuda()

    mean_a, cov_a = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda)
    mean_b, cov_b = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda)
    return calculate_frechet_distance(mean_a, cov_a, mean_b, cov_b)
if __name__ == '__main__':
    # Script entry point: parse the CLI, select the GPU and print the score.
    cli_args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = cli_args.gpu
    # An empty --gpu value means "run on the CPU".
    use_cuda = cli_args.gpu != ''
    score = calculate_fid_given_paths(
        cli_args.path, cli_args.batch_size, use_cuda, cli_args.dims
    )
    print('FID: ', score)
================================================
FILE: fid/inception.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        # Immutable default: a list here would be shared between all calls.
        output_blocks=(DEFAULT_BLOCK_INDEX,),
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : sequence of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super().__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, 'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            # Nothing after the last requested block is needed.
            if idx == self.last_needed_block:
                break

        return outp
def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    torchvision's Inception is instantiated without pretrained weights, the
    blocks whose structure differs in the FID Inception model are swapped for
    their patched counterparts, and the FID weights are then loaded from
    `FID_WEIGHTS_URL`.
    """
    net = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # (attribute on the torchvision model, patched replacement block)
    patched_blocks = (
        ('Mixed_5b', FIDInceptionA(192, pool_features=32)),
        ('Mixed_5c', FIDInceptionA(256, pool_features=64)),
        ('Mixed_5d', FIDInceptionA(288, pool_features=64)),
        ('Mixed_6b', FIDInceptionC(768, channels_7x7=128)),
        ('Mixed_6c', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6d', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6e', FIDInceptionC(768, channels_7x7=192)),
        ('Mixed_7b', FIDInceptionE_1(1280)),
        ('Mixed_7c', FIDInceptionE_2(2048)),
    )
    for attr_name, block in patched_blocks:
        setattr(net, attr_name, block)

    weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    net.load_state_dict(weights)
    return net
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super().__init__(in_channels, pool_features)

    def forward(self, x):
        """Run the four branches on `x` and concatenate them along dim 1."""
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = x
        for conv in (self.branch3x3dbl_1, self.branch3x3dbl_2, self.branch3x3dbl_3):
            out_3x3dbl = conv(out_3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        pooled = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_5x5, out_3x3dbl, out_pool], 1)
class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super().__init__(in_channels, channels_7x7)

    def forward(self, x):
        """Run the four branches on `x` and concatenate them along dim 1."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = x
        for conv in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            out_7x7 = conv(out_7x7)

        out_7x7dbl = x
        for conv in (
            self.branch7x7dbl_1,
            self.branch7x7dbl_2,
            self.branch7x7dbl_3,
            self.branch7x7dbl_4,
            self.branch7x7dbl_5,
        ):
            out_7x7dbl = conv(out_7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        pooled = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_7x7, out_7x7dbl, out_pool], 1)
class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        """Run the four branches on `x` and concatenate them along dim 1."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1
        )

        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            [self.branch3x3dbl_3a(stem_3x3dbl), self.branch3x3dbl_3b(stem_3x3dbl)], 1
        )

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        pooled = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)
class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super().__init__(in_channels)

    def forward(self, x):
        """Run the four branches on `x` and concatenate them along dim 1."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1
        )

        stem_3x3dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            [self.branch3x3dbl_3a(stem_3x3dbl), self.branch3x3dbl_3b(stem_3x3dbl)], 1
        )

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, out_pool], 1)
================================================
FILE: models/.gitkeep
================================================
================================================
FILE: requirements-colab.txt
================================================
fastai==1.0.60
tensorboardX>=1.6
ffmpeg-python
yt-dlp
opencv-python>=4.2.0.32
Pillow
tornado
imgaug==0.2.6
================================================
FILE: requirements-dev.txt
================================================
black
pre-commit
================================================
FILE: requirements.txt
================================================
wandb
fastai==1.0.60
tensorboardX>=1.6
ffmpeg
ffmpeg-python
yt-dlp
jupyterlab
opencv-python>=4.2.0.32
Pillow==9.3.0
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.11.0
torchvision==0.12.0
ipywidgets
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
def get_description():
    """Return the one-line summary passed to `setup()` as `description`."""
    summary = "Deep Learning library for colorizing and restoring old images and video"
    return summary
# def get_long_description():
# with open("README.md") as f:
# return f.read()
def get_requirements():
    """Return the requirement specifiers listed in `requirements.txt`.

    The file is read from the current working directory. Blank lines, `#`
    comment lines and pip option lines (anything starting with `-`, such as
    `--extra-index-url ...`) are skipped, because they are not requirement
    specifiers and must not be handed to `install_requires`.
    """
    requirements = []
    with open("requirements.txt", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith(("#", "-")):
                continue
            requirements.append(line)
    return requirements
# Package definition handed to setuptools.
setup(
name="DeOldify",
version="0.0.1",
# Every package found by setuptools is shipped, except "tests".
packages=find_packages(exclude=["tests"]),
url="https://github.com/jantic/DeOldify",
license="MIT License",
description=get_description(),
# long_description=get_long_description(),
# long_description_content_type="text/markdown",
classifiers=[
"Development Status :: 4 - Beta",
"Framework :: Jupyter",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
],
# Runtime dependencies come from get_requirements(), which reads requirements.txt.
install_requires=get_requirements(),
python_requires=">=3.6",
)
================================================
FILE: test_images/.gitkeep
================================================
================================================
FILE: tox.ini
================================================
[tox]
envlist=static,format
skipsdist=True
[testenv]
whitelist_externals=
/usr/bin/sh
/usr/bin/test
[testenv:format]
deps=
black
commands=
black -S --check deoldify
[testenv:static]
deps=
-rrequirements.txt
pylint
commands=
sh -c 'pylint --disable=W deoldify; test $(( $? & (1|2|4|32) )) = 0'