token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["
"]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" `.\n",
"\n",
"A model that always predicts ` ` can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab02/text_recognizer/__init__.py
================================================
"""Modules for creating and running a text recognizer."""
================================================
FILE: lab02/text_recognizer/data/__init__.py
================================================
"""Module containing submodules for each dataset.
Each dataset is defined as a class in that submodule.
The datasets should have a .config method that returns
any configuration information needed by the model.
Most datasets define their constants in a submodule
of the metadata module that is parallel to this one in the
hierarchy.
"""
from .util import BaseDataset
from .base_data_module import BaseDataModule
from .mnist import MNIST
from .emnist import EMNIST
from .emnist_lines import EMNISTLines
================================================
FILE: lab02/text_recognizer/data/base_data_module.py
================================================
"""Base DataModule class."""
import argparse
import os
from pathlib import Path
from typing import Collection, Dict, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
from text_recognizer import util
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.shared as metadata
def load_and_print_info(data_module_class) -> None:
    """Build a data module from command-line arguments, prepare and set it up, then print it.

    Works for any data module class exposing ``add_to_argparse(parser)``, a constructor taking
    the parsed ``argparse.Namespace``, ``prepare_data()``, ``setup()``, and a ``__repr__``.

    Parameters
    ----------
    data_module_class
        The data module class (not an instance) to instantiate, e.g. ``EMNIST`` or ``EMNISTLines``.
    """
    parser = argparse.ArgumentParser()
    data_module_class.add_to_argparse(parser)
    # Reads sys.argv, so this is meant to be called from a module's __main__ guard.
    args = parser.parse_args()
    dataset = data_module_class(args)
    dataset.prepare_data()
    dataset.setup()
    print(dataset)
def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
dl_dirname.mkdir(parents=True, exist_ok=True)
filename = dl_dirname / metadata["filename"]
if filename.exists():
return filename
print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
util.download_url(metadata["url"], filename)
print("Computing SHA-256...")
sha256 = util.compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
return filename
BATCH_SIZE = 128

# os.sched_getaffinity reports the CPUs this process may run on (it honors the affinity mask),
# but it only exists on some Unix platforms such as Linux. Fall back to os.cpu_count()
# elsewhere (e.g. macOS, Windows) so importing this module does not raise AttributeError.
if hasattr(os, "sched_getaffinity"):
    NUM_AVAIL_CPUS = len(os.sched_getaffinity(0))
else:
    NUM_AVAIL_CPUS = os.cpu_count() or 1  # cpu_count() may return None
NUM_AVAIL_GPUS = torch.cuda.device_count()

# sensible multiprocessing defaults: at most one worker per CPU,
# but in distributed data parallel mode we launch a training on each GPU,
# so divide out to keep the total at one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else NUM_AVAIL_CPUS
class BaseDataModule(pl.LightningDataModule):
    """Base for all of our LightningDataModules.

    Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS)

        gpus = self.args.get("gpus", None)
        # pin_memory only helps host-to-GPU copies. Lightning treats gpus=0 as "train on CPU",
        # so exclude it explicitly: isinstance(0, int) alone would enable pinning on CPU-only runs.
        self.on_gpu = isinstance(gpus, (str, int)) and gpus not in (0, "0")

        # Make sure to set the variables below in subclasses
        self.input_dims: Tuple[int, ...]
        self.output_dims: Tuple[int, ...]
        self.mapping: Collection
        # The datasets start as None so callers (e.g. subclass __repr__) can check whether setup()
        # has run; a bare annotation would leave the attribute undefined (AttributeError on access).
        self.data_train: Optional[Union[BaseDataset, ConcatDataset]] = None
        self.data_val: Optional[Union[BaseDataset, ConcatDataset]] = None
        self.data_test: Optional[Union[BaseDataset, ConcatDataset]] = None

    @classmethod
    def data_dirname(cls):
        """Return the shared root directory for datasets."""
        return metadata.DATA_DIRNAME

    @staticmethod
    def add_to_argparse(parser):
        """Register the data-loading arguments shared by every data module."""
        parser.add_argument(
            "--batch_size",
            type=int,
            default=BATCH_SIZE,
            help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=DEFAULT_NUM_WORKERS,
            help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.",
        )
        return parser

    def config(self):
        """Return important settings of the dataset, which will be passed to instantiate models."""
        return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping}

    def prepare_data(self, *args, **kwargs) -> None:
        """Take the first steps to prepare data for use.

        Use this method to do things that might write to disk or that need to be done only from a single GPU
        in distributed settings (so don't set state `self.x = y`).
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """Perform final setup to prepare data for consumption by DataLoader.

        Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return DataLoader(
            self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return DataLoader(
            self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )
================================================
FILE: lab02/text_recognizer/data/emnist.py
================================================
"""EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present."""
import json
import os
from pathlib import Path
import shutil
from typing import Sequence
import zipfile
import h5py
import numpy as np
import toml
from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset, split_dataset
import text_recognizer.metadata.emnist as metadata
from text_recognizer.stems.image import ImageStem
from text_recognizer.util import temporary_working_directory
# Short module-level aliases for the EMNIST metadata used below.
NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS
RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME
METADATA_FILENAME = metadata.METADATA_FILENAME
DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME

# When True, each class keeps at most the mean number of instances per class (see _sample_to_balance).
SAMPLE_TO_BALANCE = True
# Fraction of the official training split used for training; the remainder becomes validation.
TRAIN_FRAC = 0.8
class EMNIST(BaseDataModule):
    """EMNIST dataset of handwritten characters and digits.

    "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
    and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
    From https://www.nist.gov/itl/iad/image-group/emnist-dataset

    The data split we will use is
    EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
    """

    def __init__(self, args=None):
        super().__init__(args)

        self.mapping = metadata.MAPPING
        # character -> class index (the inverse of self.mapping)
        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
        self.transform = ImageStem()
        self.input_dims = metadata.DIMS
        self.output_dims = metadata.OUTPUT_DIMS

    def prepare_data(self, *args, **kwargs) -> None:
        """Download and process EMNIST unless the processed HDF5 file already exists."""
        if not os.path.exists(PROCESSED_DATA_FILENAME):
            _download_and_process_emnist()

    def setup(self, stage: str = None) -> None:
        """Load the processed HDF5 data into datasets.

        ``stage == "fit"`` (or None) loads the train/val split; ``"test"`` (or None) loads the test split.
        The raw arrays are also kept (x_trainval, y_trainval, x_test, y_test) because EMNISTLines
        builds synthetic lines from them directly.
        """
        if stage == "fit" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_trainval = f["x_train"][:]
                self.y_trainval = f["y_train"][:].squeeze().astype(int)

            data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
            self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

        if stage == "test" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_test = f["x_test"][:]
                self.y_test = f["y_test"][:].squeeze().astype(int)
            self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)

    def __repr__(self):
        basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n"
        # getattr: the splits may not exist yet if setup() has not run.
        splits = [getattr(self, name, None) for name in ("data_train", "data_val", "data_test")]
        if all(split is None for split in splits):
            return basic

        sizes = ", ".join("N/A" if split is None else str(len(split)) for split in splits)
        data = f"Train/val/test sizes: {sizes}\n"
        # A batch can only be drawn once the training split exists (not after setup("test") alone).
        if splits[0] is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
            )
        return basic + data
def _download_and_process_emnist():
    """Fetch the raw EMNIST archive (if needed) and convert it into the processed files.

    The download description (URL, filename, SHA-256) is read from the TOML file at METADATA_FILENAME.
    """
    # Named ``raw_info`` so it does not shadow the module-level ``metadata`` import.
    raw_info = toml.load(METADATA_FILENAME)
    _download_raw_dataset(raw_info, DL_DATA_DIRNAME)
    archive_name = raw_info["filename"]
    _process_raw_dataset(archive_name, DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path):
    """Convert the downloaded EMNIST zip into the processed HDF5 file and the essentials JSON.

    Works inside ``dirname``: extracts ``matlab/emnist-byclass.mat`` from the zip ``filename``,
    loads the ByClass train/test splits, optionally class-balances them, writes them to
    PROCESSED_DATA_FILENAME, and writes the augmented character list plus input shape to
    ESSENTIALS_FILENAME. The extracted ``matlab`` directory is removed afterwards, also when
    a step fails part-way, so a large .mat file is not left behind.
    """
    print("Unzipping EMNIST...")
    with temporary_working_directory(dirname):
        try:
            with zipfile.ZipFile(filename, "r") as zf:
                zf.extract("matlab/emnist-byclass.mat")

            # NOTE: If importing at the top of module, would need to list scipy as prod dependency.
            from scipy.io import loadmat

            print("Loading training data from .mat file")
            data = loadmat("matlab/emnist-byclass.mat")
            # reshape(-1, 28, 28).swapaxes(1, 2) transposes the two spatial axes of every image
            x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
            # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices
            y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
            x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
            y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS

            if SAMPLE_TO_BALANCE:
                print("Balancing classes to reduce amount of data")
                x_train, y_train = _sample_to_balance(x_train, y_train)
                x_test, y_test = _sample_to_balance(x_test, y_test)

            print("Saving to HDF5 in a compressed format...")
            PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
            with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
                f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
                f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
                f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
                f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")

            print("Saving essential dataset parameters to text_recognizer/data...")
            mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
            characters = _augment_emnist_characters(list(mapping.values()))
            essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
            with open(ESSENTIALS_FILENAME, "w") as f:
                json.dump(essentials, f)
        finally:
            print("Cleaning up...")
            # ignore_errors: if extraction itself failed, "matlab" may not exist, and raising here
            # would mask the original exception.
            shutil.rmtree("matlab", ignore_errors=True)
def _sample_to_balance(x, y):
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(42)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_inds = []
for label in np.unique(y.flatten()):
inds = np.where(y == label)[0]
sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
all_sampled_inds.append(sampled_inds)
ind = np.concatenate(all_sampled_inds)
x_sampled = x[ind]
y_sampled = y[ind]
return x_sampled, y_sampled
def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset
iam_characters = [
" ",
"!",
'"',
"#",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
":",
";",
"?",
]
# Also add special tokens:
# - CTC blank token at index 0
# - Start token at index 1
# - End token at index 2
# - Padding token at index 3
# NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this!
return ["", " ", *characters, *iam_characters]
if __name__ == "__main__":
    # Running this module directly prepares (downloading/processing if needed), sets up,
    # and prints a summary of the EMNIST data module.
    load_and_print_info(EMNIST)
================================================
FILE: lab02/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab02/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
# Paths re-exported from the EMNIST Lines metadata module.
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME

# Line-generation defaults; overridable through the matching add_to_argparse flags.
DEFAULT_MAX_LENGTH = 32  # maximum number of characters per line
DEFAULT_MIN_OVERLAP = 0  # minimum fraction of glyph width shared by neighbouring characters
DEFAULT_MAX_OVERLAP = 0.33  # maximum fraction of glyph width shared by neighbouring characters

# Number of synthetic lines generated per split (read from args num_train / num_val / num_test when given).
NUM_TRAIN = 10000
NUM_VAL = 2000
NUM_TEST = 2000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.

    ``prepare_data`` generates train/val/test lines once (sentences from ``SentenceGenerator``
    rendered with randomly sampled EMNIST glyphs). It caches them in an HDF5 file whose name
    encodes the generation parameters (see ``data_filename``).
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)

        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)

        self.mapping = metadata.MAPPING
        # one label per character position
        self.output_dims = (self.max_length, 1)
        # wide enough for max_length glyphs placed side by side without overlap
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)

        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """HDF5 cache path; the name encodes every generation parameter."""
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate and cache all three splits unless the cache file already exists."""
        # NOTE(review): splits are appended to the file one at a time (mode "a"), so an interrupted
        # run leaves a partial file that this exists() check will then accept.
        if self.data_filename.exists():
            return
        np.random.seed(42)
        self._generate_data("train")
        self._generate_data("val")
        self._generate_data("test")

    def setup(self, stage: str = None) -> None:
        """Load the cached splits: train/val for ``"fit"``, test for ``"test"``, all for None."""
        print("EMNISTLinesDataset loading data from HDF5...")
        if stage == "fit" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)

            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)

        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset."""
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        # getattr: the splits may not exist yet if setup() has not run.
        splits = [getattr(self, name, None) for name in ("data_train", "data_val", "data_test")]
        if all(split is None for split in splits):
            return basic

        sizes = ", ".join("N/A" if split is None else str(len(split)) for split in splits)
        data = f"Train/val/test sizes: {sizes}\n"
        # A batch can only be drawn once the training split exists (not after setup("test") alone).
        if splits[0] is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
            )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Render ``split`` lines from EMNIST glyphs and append them to the HDF5 cache file."""
        print(f"EMNISTLinesDataset generating data for {split}...")

        from text_recognizer.data.sentence_generator import SentenceGenerator

        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract two because we will add start/end tokens

        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()

        # NOTE(review): train and val lines are both assembled from emnist.x_trainval, so validation
        # lines may reuse the exact glyph images seen in training.
        if split == "train":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train
        elif split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test

        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Group ``samples`` by the character that ``mapping[label]`` assigns to each one.

    Returns a ``defaultdict(list)``, so looking up a character with no samples yields ``[]``.
    """
    grouped = defaultdict(list)
    for label, sample in zip(labels, samples):
        character = mapping[label]
        grouped[character].append(sample)
    return grouped
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Pick one glyph image per distinct character of ``string``, reshaped to ``char_shape``.

    Every occurrence of a character reuses the glyph chosen for its first occurrence. Characters
    with no samples get an all-zero uint8 image.
    """
    blank = torch.zeros(char_shape, dtype=torch.uint8)
    chosen = {}
    for char in string:
        if char in chosen:
            continue
        candidates = samples_by_char[char]
        if candidates:
            picked = candidates[np.random.choice(len(candidates))]
        else:
            picked = blank
        chosen[char] = picked.reshape(*char_shape)
    return [chosen[char] for char in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render ``string`` as a single line image by pasting one sampled glyph per character.

    One overlap fraction is drawn uniformly from ``[min_overlap, max_overlap)`` for the whole
    line. Consecutive glyphs are shifted right by ``W - int(overlap * W)`` pixels, where ``W``
    is the glyph width. Where glyphs overlap, pixel values are summed and saturated at 255.

    Returns
    -------
    torch.Tensor
        (H, width) float32 tensor with values in [0, 255].
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in int32: summing overlapping glyphs directly in uint8 wraps around past 255
    # (e.g. 200 + 100 -> 44), which turned the brightest stroke crossings dark.
    concatenated_image = torch.zeros((H, width), dtype=torch.int32)
    x = 0
    for image in sampled_images:
        # as_tensor accepts both numpy glyphs and the torch zero image used for missing characters
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image, dtype=torch.int32)
        x += next_overlap_width
    # Saturate at 255. float32 matches the dtype that the previous
    # torch.minimum(torch.Tensor([255]), ...) produced.
    return concatenated_image.clamp(max=255).to(torch.float32)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate ``N`` synthetic line images together with their ground-truth strings.

    ``dims`` is indexed as ``dims[1]`` (height) and ``dims[2]`` / ``dims[-1]`` (width).
    Returns ``(images, labels)``: an (N, dims[1], dims[2]) float tensor and a list of N strings.
    """
    line_width = dims[-1]
    images = torch.zeros((N, dims[1], dims[2]))
    labels = []
    for index in range(N):
        text = sentence_generator.generate()
        images[index] = construct_image_from_string(text, samples_by_char, min_overlap, max_overlap, line_width)
        labels.append(text)
    return images, labels
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" ",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
" ",
"!",
'"',
"#",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
":",
";",
"?",
]
================================================
FILE: lab02/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
# Where generated EMNIST Lines HDF5 files are cached.
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
# Essentials JSON, stored in the package's data/ directory.
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"

# Each character in a line is one EMNIST glyph, so reuse EMNIST's height and width.
CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
# Width is None: it is variable and depends on the maximum sequence length.
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)

# Lines share EMNIST's character vocabulary.
MAPPING = emnist.MAPPING
================================================
FILE: lab02/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME

DIMS = (1, 28, 28)  # shape of one input image
OUTPUT_DIMS = (1,)  # one class label per image
MAPPING = [*range(10)]  # class index i corresponds to digit i

# Train / validation split sizes.
TRAIN_SIZE = 55000
VAL_SIZE = 5000
================================================
FILE: lab02/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# parents[3] of <root>/lab02/text_recognizer/metadata/shared.py is <root>, so data lives in <root>/data.
_REPO_ROOT = Path(__file__).resolve().parents[3]
DATA_DIRNAME = _REPO_ROOT / "data"
# Raw downloads go in a subdirectory of the data directory.
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded"
================================================
FILE: lab02/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
================================================
FILE: lab02/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
# Default CNN hyperparameters; each can be overridden via the flags in CNN.add_to_argparse.
CONV_DIM = 64  # output channels of each ConvBlock
FC_DIM = 128  # width of the hidden fully-connected layer
FC_DROPOUT = 0.25  # dropout probability applied after max-pooling
class ConvBlock(nn.Module):
    """A 3x3 convolution followed by a ReLU.

    Stride 1 with padding 1 keeps the spatial size, so (B, C_in, H, W) becomes (B, C_out, H, W).
    """

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        # Submodule names are kept as "conv"/"relu" so state_dict keys stay stable.
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` of shape (B, C, H, W), then rectify; the output keeps H and W."""
        return self.relu(self.conv(x))
class CNN(nn.Module):
    """Simple CNN for recognizing characters in a square image.

    Pipeline: ConvBlock -> ConvBlock -> 2x2 max-pool -> dropout -> flatten -> fc1 -> ReLU -> fc2 (logits).
    """

    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_channels, input_height, input_width = self.data_config["input_dims"]
        assert (
            input_height == input_width
        ), f"input height and width should be equal, but was {input_height}, {input_width}"
        self.input_height = input_height
        self.input_width = input_width

        num_classes = len(self.data_config["mapping"])
        hparams = self.args
        conv_dim = hparams.get("conv_dim", CONV_DIM)
        fc_dim = hparams.get("fc_dim", FC_DIM)
        fc_dropout = hparams.get("fc_dropout", FC_DROPOUT)

        # Keep this construction order: it fixes the order in which weights are randomly initialized.
        self.conv1 = ConvBlock(input_channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)

        # The padded 3x3 convs preserve H and W; only the 2x2 max-pool halves them.
        pooled_area = (input_height // 2) * (input_width // 2)
        self.fc_input_dim = int(pooled_area * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Parameters
        ----------
        x
            (B, Ch, H, W) tensor; H and W must equal the input height and width from data_config.

        Returns
        -------
        torch.Tensor
            (B, Cl) logits, one per class in the mapping.
        """
        _batch, _channels, height, width = x.shape
        assert height == self.input_height and width == self.input_width, f"bad inputs to CNN with shape {x.shape}"

        features = self.max_pool(self.conv2(self.conv1(x)))  # (B, CONV_DIM, H // 2, W // 2)
        features = torch.flatten(self.dropout(features), 1)  # (B, CONV_DIM * H // 2 * W // 2)
        hidden = F.relu(self.fc1(features))  # (B, FC_DIM)
        return self.fc2(hidden)  # (B, Cl)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab02/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
IMAGE_SIZE = 28  # required height of input line images (checked in forward)
# By default the window is IMAGE_SIZE wide and advances by its own width,
# so consecutive windows tile the line without overlapping.
WINDOW_WIDTH = WINDOW_STRIDE = IMAGE_SIZE
class LineCNNSimple(nn.Module):
    """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.

    A window ``window_width`` pixels wide slides across the line in steps of ``window_stride``
    pixels, and each window is classified independently by a ``CNN``.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        option = self.args.get
        self.WW = option("window_width", WINDOW_WIDTH)
        self.WS = option("window_stride", WINDOW_STRIDE)
        self.limit_output_length = option("limit_output_length", False)

        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        # The inner CNN sees square WW x WW crops with the line image's channel count.
        channels = data_config["input_dims"][0]
        cnn_data_config = dict(data_config, input_dims=(channels, self.WW, self.WW))
        self.cnn = CNN(data_config=cnn_data_config, args=args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the LineCNN to an input image and return logits.

        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to IMAGE_SIZE

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where C is self.num_classes and S is the number of windows,
            floor((W - window_width) / window_stride + 1), optionally truncated to output_length.
        """
        batch_size, _channels, height, width = x.shape
        assert height == IMAGE_SIZE  # Make sure we can use our CNN class

        num_windows = math.floor((width - self.WW) / self.WS + 1)
        # NOTE: type_as gives the buffer x's dtype and device
        activations = torch.zeros((batch_size, self.num_classes, num_windows)).type_as(x)
        for s in range(num_windows):
            left = s * self.WS
            activations[:, :, s] = self.cnn(x[:, :, :, left : left + self.WW])

        if self.limit_output_length:
            # S might not match the ground-truth length, so keep only as many activations as expected
            activations = activations[:, :, : self.output_length]
        return activations

    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab02/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default MLP hyperparameters; overridable via the flags in MLP.add_to_argparse.
FC1_DIM = 1024  # width of the first hidden layer
FC2_DIM = 128  # width of the second hidden layer
FC_DROPOUT = 0.5  # dropout probability after each hidden layer
class MLP(nn.Module):
    """Simple MLP suitable for recognizing single characters.

    flatten -> fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout -> fc3 (logits).
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_dim = np.prod(self.data_config["input_dims"])
        num_classes = len(self.data_config["mapping"])

        option = self.args.get
        fc1_dim = option("fc1", FC1_DIM)
        fc2_dim = option("fc2", FC2_DIM)
        dropout_p = option("fc_dropout", FC_DROPOUT)

        # Construction order is kept: it determines the order of random weight initialization.
        self.fc1 = nn.Linear(input_dim, fc1_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, num_classes)

    def forward(self, x):
        """Flatten ``x`` per example and return (B, num_classes) logits."""
        x = torch.flatten(x, 1)
        # Both hidden layers share the same ReLU + dropout treatment.
        for hidden_layer in (self.fc1, self.fc2):
            x = self.dropout(F.relu(hidden_layer(x)))
        return self.fc3(x)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--fc1", type=int, default=FC1_DIM)
        parser.add_argument("--fc2", type=int, default=FC2_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab02/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """Preprocessing stem for models that consume images.

    Inputs are expected to be PIL images, as is standard for torchvision Datasets.
    Each call runs three stages in order:

    1. ``pil_transforms`` -- PIL image in, PIL image out (identity by default).
    2. ``pil_to_tensor`` -- converts the PIL image to a tensor (``transforms.ToTensor``).
    3. ``torch_transforms`` -- tensor in, tensor out (identity by default), run under
       ``torch.no_grad``. Because this stage is a ``torch.nn.Sequential``, it is
       TorchScript-compatible whenever the underlying Modules are.

    Subclasses customize preprocessing by replacing ``pil_transforms`` and/or
    ``torch_transforms``.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()

    def __call__(self, img):
        tensor = self.pil_to_tensor(self.pil_transforms(img))
        with torch.no_grad():
            return self.torch_transforms(tensor)
class MNISTStem(ImageStem):
    """Image stem for MNIST data.

    Identical to ``ImageStem`` except that the tensor stage applies
    ``transforms.Normalize`` with mean ``0.1307`` and standard deviation ``0.3081``.
    """

    def __init__(self):
        super().__init__()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.torch_transforms = torch.nn.Sequential(normalize)
================================================
FILE: lab02/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """One-hot encode integer label(s) ``y`` as ``uint8`` rows of a ``num_classes`` identity matrix.

    Implemented by indexing ``np.eye(num_classes)`` with ``y``. A single int yields one row;
    an integer array yields one row per label.
    """
    identity = np.eye(num_classes, dtype="uint8")
    return identity[y]
def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Open ``image_uri`` and return its contents as a PIL image.

    The URI is opened in binary mode with ``smart_open.open`` and decoded by
    ``read_image_pil_file``.

    Args:
        image_uri: path or URI of the image to read.
        grayscale: if True, the image is converted to single-channel ``"L"`` mode.

    Returns:
        A ``PIL.Image.Image``. (The previous annotation ``-> Image`` named the
        ``PIL.Image`` module rather than the image class.)
    """
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Decode an image from an open binary file object.

    Args:
        image_file: a file object (or anything else accepted by ``PIL.Image.open``).
        grayscale: if True, convert to single-channel ``"L"`` mode; otherwise keep the
            image's original mode.

    Returns:
        A ``PIL.Image.Image`` produced by ``convert``. ``convert`` returns a new, fully
        loaded image, so the result stays usable after the ``with`` block closes the
        image object that ``Image.open`` returned.
    """
    with Image.open(image_file) as image:
        if grayscale:
            image = image.convert(mode="L")
        else:
            # Converting to the same mode still materializes an independent, loaded copy.
            image = image.convert(mode=image.mode)
        return image
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Run the ``with`` body with ``working_dir`` as the process's current directory.

    On exit the previous working directory is restored, including when the body raises.
    """
    previous_dir = os.getcwd()
    os.chdir(working_dir)
    try:
        yield
    finally:
        # Restore unconditionally so an exception in the body cannot strand the process elsewhere.
        os.chdir(previous_dir)
def compute_sha256(filename: Union[Path, str], chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 checksum of a file.

    The file is hashed incrementally in ``chunk_size``-byte reads. Memory use therefore
    stays bounded even for large downloaded archives; the previous version read the
    whole file into memory at once. The digest is identical either way.

    Args:
        filename: path of the file to hash.
        chunk_size: number of bytes read per step (default 1 MiB). Must be positive.

    Returns:
        The hexadecimal SHA-256 digest.

    Raises:
        ValueError: if ``chunk_size`` is not positive. ``read(0)`` would return ``b""``
            and silently end the loop, giving the digest of an empty file.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """tqdm progress bar that can be driven by ``urlretrieve``'s ``reporthook``.

    Adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """Move the bar to ``blocks * bsize`` units transferred.

        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). When None [default], the current total is kept.
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # tqdm.update expects an increment, so pass the delta from the current position self.n.
        self.update(transferred - self.n)
def download_url(url, filename):
    """Fetch ``url`` into ``filename`` with urlretrieve, showing a byte-count progress bar."""
    bar_options = dict(unit="B", unit_scale=True, unit_divisor=1024, miniters=1)
    with TqdmUpTo(**bar_options) as progress_bar:
        urlretrieve(url, filename, reporthook=progress_bar.update_to, data=None)  # noqa: S310
================================================
FILE: lab02/training/__init__.py
================================================
================================================
FILE: lab02/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# Seed the global PyTorch and NumPy random number generators at import time
# so that experiments launched from this script are reproducible.
torch.manual_seed(42)
np.random.seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments.

    Parsing happens in two passes. A first ``parse_known_args`` call reads
    ``--data_class`` and ``--model_class``. The corresponding classes are then
    imported so they can register their own argument groups. Every parser here is
    built with ``add_help=False``, and ``--help`` is added only at the very end, so
    an early ``--help`` does not exit before the class-specific arguments exist.

    Returns:
        argparse.ArgumentParser: parser with Trainer, basic, data, model, and LitModel arguments.
    """
    parser = argparse.ArgumentParser(add_help=False)
    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    # NOTE(review): touches argparse's private `_action_groups`; which group sits at index 1
    # depends on how the installed Lightning version adds its arguments -- confirm it is the
    # intended group.
    trainer_parser._action_groups[1].title = "Trainer Args"
    # Build a fresh parser that inherits the Trainer arguments via `parents`.
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    # Override the Trainer default so runs are one epoch unless --max_epochs is passed.
    parser.set_defaults(max_epochs=1)
    # Basic arguments
    parser.add_argument(
        "--data_class",
        type=str,
        default="MNIST",
        help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--model_class",
        type=str,
        default="MLP",
        help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path."
    )
    parser.add_argument(
        "--stop_early",
        type=int,
        default=0,
        help="If non-zero, applies early stopping, with the provided value as the 'patience' argument."
        + " Default is 0.",
    )
    # Get the data and model classes, so that we can add their specific arguments
    # (parse_known_args tolerates the class-specific flags that are not registered yet).
    temp_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}")
    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)
    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)
    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_models.BaseLitModel.add_to_argparse(lit_model_group)
    # Added last so that --help lists every argument registered above.
    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create ``experiment_dir`` (and any missing parents) from the rank-zero process only."""
    log_path = Path(experiment_dir)
    log_path.mkdir(exist_ok=True, parents=True)
def main():
    """
    Run an experiment.

    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```

    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```

    The available command line args differ depending on some of the arguments, including --model_class and --data_class.

    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```
    """
    args = _setup_parser().parse_args()
    data, model = setup_data_and_model_from_args(args)

    lit_model_class = lit_models.BaseLitModel
    if args.load_checkpoint is None:
        lit_model = lit_model_class(args=args, model=model)
    else:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)

    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    logger = pl.loggers.TensorBoardLogger(log_dir)
    experiment_dir = logger.log_dir

    # Checkpoints are ranked by character error rate for the transformer loss, else by validation loss.
    monitored_metric = "validation/loss"
    if args.loss in ("transformer",):
        monitored_metric = "validation/cer"

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename="epoch={epoch:04d}-validation.loss={validation/loss:.3f}",
        monitor=monitored_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )
    callbacks = [pl.callbacks.ModelSummary(max_depth=2), checkpoint_callback]
    if args.stop_early:
        callbacks.append(
            pl.callbacks.EarlyStopping(monitor="validation/loss", mode="min", patience=args.stop_early)
        )

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)
    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate
    trainer.fit(lit_model, datamodule=data)

    best_model_path = checkpoint_callback.best_model_path
    if not best_model_path:
        trainer.test(lit_model, datamodule=data)
        return
    rank_zero_info(f"Best model saved at: {best_model_path}")
    trainer.test(datamodule=data, ckpt_path=best_model_path)


if __name__ == "__main__":
    main()
================================================
FILE: lab02/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
DATA_CLASS_MODULE = "text_recognizer.data"
MODEL_CLASS_MODULE = "text_recognizer.models"


def import_class(module_and_class_name: str) -> type:
    """Import and return a class given its fully qualified dotted name.

    Everything before the last dot is imported as a module; the final component is
    looked up on it, e.g. ``import_class("text_recognizer.models.MLP")``.
    """
    module_name, class_name = module_and_class_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Instantiate the data module and model named by ``args.data_class`` / ``args.model_class``.

    Both classes are resolved before anything is constructed. The data module is built
    from ``args``; the model is built from the data module's ``config()`` plus ``args``.

    Returns:
        A ``(data, model)`` tuple.
    """
    data_cls = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}")
    model_cls = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}")
    data = data_cls(args)
    return data, model_cls(data_config=data.config(), args=args)
================================================
FILE: lab03/notebooks/lab01_pytorch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" `.\n",
"\n",
"A model that always predicts ` ` can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab03/notebooks/lab03_transformers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" \", \"\")\n",
"\n",
"idx = random.randint(0, len(xs))\n",
"\n",
"print(show(ys[idx]))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4dT3UCNzTsoc"
},
"source": [
"The `ResnetTransformer` model can run on this data\n",
"if passed the `.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXL-vIGRr86D"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MMxa-oWyT01E"
},
"source": [
"Our models are now big enough\n",
"that we want to make use of GPU acceleration\n",
"as much as we can,\n",
"even when working on single inputs,\n",
"so let's cast to the GPU if we have one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-YyUM8LgvW0w"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"rnt.to(device); xs = xs.to(device); ys = ys.to(device);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-E3UdD4zUJi"
},
"source": [
"First, let's just pass it through the ResNet encoder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-LUUtlvaxrvg"
},
"outputs": [],
"source": [
"resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n",
" # resnet is designed for RGB images, so we replicate the input across channels 3 times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eimgJ5dnywjg"
},
"outputs": [],
"source": [
"resnet_idx = random.randint(0, len(resnet_embedding)) # re-execute to view a different channel\n",
"plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n",
"plt.axis(\"off\"); plt.colorbar(fraction=0.05);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These embeddings, though generated by random, untrained weights,\n",
"are not entirely useless.\n",
"\n",
"Before neural networks could be effectively\n",
"trained end to end,\n",
"they were often used with frozen random weights\n",
"everywhere except the final layer\n",
"(see e.g.\n",
"[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n",
"[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n",
"these methods were still competitive, and\n",
"[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n",
"provide a\n",
"[theoretical basis](https://arxiv.org/abs/2011.14522)\n",
"for understanding their performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ye6pW0ETzw2A"
},
"source": [
"The final result, though, is repetitive gibberish --\n",
"at the bare minimum, we need to train the unembedding/readout layer\n",
"in order to get reasonable text."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our architecture includes randomization with dropout,\n",
"so repeated runs of the cell below will generate different outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu3Pa7gLsFMo"
},
"outputs": [],
"source": [
"preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gvCXUbskv6XM"
},
"outputs": [],
"source": [
"print(show(preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without teacher forcing, runtime is also variable from iteration to iteration --\n",
"the model stops when it generates an \"end sequence\" or padding token,\n",
"which is not deterministic thanks to the dropout layers.\n",
"For similar reasons, runtime is variable across inputs.\n",
"\n",
"The variable runtime of autoregressive generation\n",
"is also not great for scaling.\n",
"In a distributed setting, as required for large scale,\n",
"forward passes need to be synced across devices,\n",
"and if one device is generating a batch of much longer sequences,\n",
"it will cause all the others to idle while they wait on it to finish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t76MSVRXV0V7"
},
"source": [
"Let's turn our model into a `TransformerLitModel`\n",
"so we can run with teacher forcing.\n",
"\n",
"> You may be wondering:\n",
" why isn't teacher forcing part of the PyTorch module?\n",
" In general, the `LightningModule`\n",
" should encapsulate things that are needed in training, validation, and testing\n",
" but not during inference.\n",
" The teacher forcing trick fits this paradigm,\n",
" even though it's so critical to what makes Transformers powerful. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qrHRKHowdDi"
},
"outputs": [],
"source": [
"import text_recognizer.lit_models\n",
"\n",
"lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlNaFqR50Oid"
},
"source": [
"Now we can use `.teacher_forward` if we also provide the target `ys`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lpZdqXS5wn0F"
},
"outputs": [],
"source": [
"forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Zx9SmsN0QLT"
},
"source": [
"This may not run faster than the `rnt.forward`,\n",
"since generations are always the maximum possible length,\n",
"but runtimes and output lengths are deterministic and constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tu-XNYpi0Qvi"
},
"source": [
"Forcing doesn't necessarily make our predictions better.\n",
"They remain highly repetitive gibberish."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JcEgify9w0sv"
},
"outputs": [],
"source": [
"forcing_preds = torch.argmax(forcing_outs, dim=0)\n",
"\n",
"print(show(forcing_preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xn6GGNzc9a3o"
},
"source": [
"## Training the `ResNetTransformer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uvZYsuSyWUXe"
},
"source": [
"We're finally ready to train this model on full paragraphs of handwritten text!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cJwC7b720Sd"
},
"source": [
"This is a more serious model --\n",
"it's the one we use in the\n",
"[deployed TextRecognizer application](http://fsdl.me/app).\n",
"It's much larger than the models we've seen thus far,\n",
"so it can easily outstrip available compute resources,\n",
"in particular GPU memory.\n",
"\n",
"To help, we use\n",
"[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n",
"which shrinks the size of most of our floats by half,\n",
"which reduces memory consumption and can speed up computation.\n",
"\n",
"If your GPU has less than 8GB of available RAM,\n",
"you'll see a \"CUDA out of memory\" `RuntimeError`,\n",
"which is something of a\n",
"[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n",
"In this case, you can resolve it by reducing the `--batch_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w1mXlhfy04Nm"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gpus = int(torch.cuda.is_available())\n",
"\n",
"if gpus:\n",
" !nvidia-smi\n",
"else:\n",
" print(\"watch out! working with this model on a typical CPU is not feasible\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "os1vW1rPZ1dy"
},
"source": [
"Even with an okay GPU, like a\n",
"[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n",
"a single epoch of training can take over 10 minutes to run.\n",
"We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n",
"but you can remove those flags to see what full training looks like."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnF6dWFn4JlZ"
},
"source": [
"It can take a long time (overnight)\n",
"to train this model to decent performance on a single GPU,\n",
"so we'll focus on other pieces for the exercises.\n",
"\n",
"> At the time of writing in mid-2022, the cheapest readily available option\n",
"for training this model to decent performance on this dataset with this codebase\n",
"comes out around $10, using\n",
"[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n",
"See, for example,\n",
"[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n",
"and associated experiment.\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HufjdUZN0t4l",
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"# above %%magic times the cell, useful as a poor man's profiler\n",
"\n",
"%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n",
" --gpus={gpus} --batch_size 16 --precision 16 \\\n",
" --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L6fQ93ju3Iku"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udb1Ekjx3L63"
},
"source": [
"### 🌟 Try out gradient accumulation and other \"training tricks\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kpqViB4p3Wfb"
},
"source": [
"Larger batches are helpful not only for increasing parallelization\n",
"and amortizing fixed costs\n",
"but also for getting more reliable gradients.\n",
"Larger batches give gradients with less noise\n",
"and to a point, less gradient noise means faster convergence.\n",
"\n",
"But larger batches result in larger tensors,\n",
"which take up more GPU memory,\n",
"a resource that is tightly constrained\n",
"and device-dependent.\n",
"\n",
"Does that mean we are limited in the quality of our gradients\n",
"due to our machine size?\n",
"\n",
"Not entirely:\n",
"look up the `--accumulate_grad_batches`\n",
"argument to the `pl.Trainer`.\n",
"You should be able to understand why\n",
"it makes it possible to compute the same gradients\n",
"you would find for a batch of size `k * N`\n",
"on a machine that can only run batches up to size `N`.\n",
"\n",
"Accumulating gradients across batches is among the\n",
"[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n",
"Try some of them out!\n",
"Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2vtkmX830y3"
},
"source": [
"### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n",
"\n",
"While training this model to actually fit the whole dataset is infeasible\n",
"as a short exercise on commodity hardware,\n",
"it's practical to train this model to memorize a batch of 16 examples.\n",
"\n",
"Passing the `--overfit_batches 1` flag limits the number of training batches to 1\n",
"and turns off\n",
"[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n",
"so that in each epoch, the model just sees the same single batch of data over and over again.\n",
"\n",
"At first, try training the model to a loss of `2.5` --\n",
"it should be doable in 100 epochs or less,\n",
"which is just a few minutes on a commodity GPU.\n",
"\n",
"Once you've got that working,\n",
"crank up the number of epochs by a factor of 10\n",
"and confirm that the loss continues to go down.\n",
"\n",
"Some tips:\n",
"\n",
"- Use `--limit_test_batches 0` to turn off testing.\n",
"We don't need it because we don't care about generalization\n",
"and it's relatively slow because it runs the model autoregressively.\n",
"\n",
"- Use `--help` and look through the model class args\n",
"to find the arguments used to reduce model size.\n",
"\n",
"- By default, there's lots of regularization to prevent overfitting.\n",
"Look through the args for the model class and data class\n",
"for regularization knobs to turn off or down."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab03_transformers.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab03/text_recognizer/__init__.py
================================================
"""Modules for creating and running a text recognizer."""
================================================
FILE: lab03/text_recognizer/data/__init__.py
================================================
"""Module containing submodules for each dataset.
Each dataset is defined as a class in that submodule.
The datasets should have a .config method that returns
any configuration information needed by the model.
Most datasets define their constants in a submodule
of the metadata module that is parallel to this one in the
hierarchy.
"""
from .util import BaseDataset
from .base_data_module import BaseDataModule
from .mnist import MNIST
from .emnist import EMNIST
from .emnist_lines import EMNISTLines
from .iam_paragraphs import IAMParagraphs
================================================
FILE: lab03/text_recognizer/data/base_data_module.py
================================================
"""Base DataModule class."""
import argparse
import os
from pathlib import Path
from typing import Collection, Dict, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
from text_recognizer import util
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.shared as metadata
def load_and_print_info(data_module_class) -> None:
    """Build ``data_module_class`` from command-line arguments, prepare and set it up, then print it.

    Works for any data module class exposing ``add_to_argparse``, ``prepare_data`` and ``setup``.
    """
    parser = argparse.ArgumentParser()
    data_module_class.add_to_argparse(parser)
    data_module = data_module_class(parser.parse_args())
    data_module.prepare_data()
    data_module.setup()
    print(data_module)
def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
    """Download the raw file described by ``metadata`` into ``dl_dirname`` and verify it.

    ``metadata`` must provide ``"url"``, ``"filename"`` and ``"sha256"`` entries. If the
    destination file already exists, it is returned without re-verification.

    The download goes to a sibling ``<filename>.part`` file. That file is moved to its
    final name only after its SHA-256 matches. Previously, a corrupted or interrupted
    download stayed at the final path, and every later call returned it unverified
    through the ``exists()`` shortcut.

    Returns:
        Path to the verified file.

    Raises:
        ValueError: if the downloaded file's SHA-256 does not match ``metadata["sha256"]``
            (the partial file is removed first).
    """
    dl_dirname.mkdir(parents=True, exist_ok=True)
    filename = dl_dirname / metadata["filename"]
    if filename.exists():
        return filename
    partial_filename = filename.with_name(filename.name + ".part")
    print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
    util.download_url(metadata["url"], partial_filename)
    print("Computing SHA-256...")
    sha256 = util.compute_sha256(partial_filename)
    if sha256 != metadata["sha256"]:
        partial_filename.unlink()
        raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
    # Only a verified download ever appears under the final name.
    partial_filename.replace(filename)
    return filename
BATCH_SIZE = 128

# os.sched_getaffinity reports only the CPUs this process may run on, but it exists only on
# some Unix platforms; elsewhere (e.g. macOS, Windows) fall back to the machine's CPU count
# so that importing this module does not raise AttributeError.
try:
    NUM_AVAIL_CPUS = len(os.sched_getaffinity(0))
except AttributeError:
    NUM_AVAIL_CPUS = os.cpu_count() or 1  # os.cpu_count() may return None
NUM_AVAIL_GPUS = torch.cuda.device_count()

# sensible multiprocessing defaults: at most one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS
# but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS
class BaseDataModule(pl.LightningDataModule):
    """Base for all of our LightningDataModules.

    Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html

    Subclasses are expected to set ``input_dims``, ``output_dims`` and ``mapping`` (reported by
    ``config``) and to assign ``data_train``, ``data_val`` and ``data_test`` (consumed by the
    dataloaders).
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS)
        # A --gpus value given as a str or int enables pinned memory in the dataloaders.
        self.on_gpu = isinstance(self.args.get("gpus", None), (str, int))

        # Make sure to set the variables below in subclasses
        self.input_dims: Tuple[int, ...]
        self.output_dims: Tuple[int, ...]
        self.mapping: Collection
        self.data_train: Union[BaseDataset, ConcatDataset]
        self.data_val: Union[BaseDataset, ConcatDataset]
        self.data_test: Union[BaseDataset, ConcatDataset]

    @classmethod
    def data_dirname(cls):
        """Return the shared data directory from the metadata module."""
        return metadata.DATA_DIRNAME

    @staticmethod
    def add_to_argparse(parser):
        """Register ``--batch_size`` and ``--num_workers`` on ``parser`` and return it."""
        parser.add_argument(
            "--batch_size",
            type=int,
            default=BATCH_SIZE,
            help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=DEFAULT_NUM_WORKERS,
            help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.",
        )
        return parser

    def config(self):
        """Return important settings of the dataset, which will be passed to instantiate models."""
        return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping}

    def prepare_data(self, *args, **kwargs) -> None:
        """Take the first steps to prepare data for use (no-op here).

        Override for work that writes to disk or must happen on a single process in
        distributed settings; avoid setting state such as ``self.x = y`` here.
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """Perform final setup before the DataLoaders consume the data (no-op here).

        Override to split into train/validation/test; in DDP this runs once per GPU.
        Should assign ``torch`` Dataset objects to ``self.data_train``, ``self.data_val``,
        and optionally ``self.data_test``.
        """

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size, workers and pinning setting."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.data_test, shuffle=False)
================================================
FILE: lab03/text_recognizer/data/emnist.py
================================================
"""EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present."""
import json
import os
from pathlib import Path
import shutil
from typing import Sequence
import zipfile
import h5py
import numpy as np
import toml
from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset, split_dataset
import text_recognizer.metadata.emnist as metadata
from text_recognizer.stems.image import ImageStem
from text_recognizer.util import temporary_working_directory
# Dataset-construction settings local to this module.
SAMPLE_TO_BALANCE = True  # If true, take at most the mean number of instances per class.
TRAIN_FRAC = 0.8  # passed as `fraction` to split_dataset when splitting train/val

# Module-level aliases for values defined in text_recognizer.metadata.emnist.
NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS
RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME
METADATA_FILENAME = metadata.METADATA_FILENAME
DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME
class EMNIST(BaseDataModule):
    """EMNIST dataset of handwritten characters and digits.

    "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
    and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
    From https://www.nist.gov/itl/iad/image-group/emnist-dataset

    The data split we will use is
    EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
    """

    def __init__(self, args=None):
        super().__init__(args)
        self.mapping = metadata.MAPPING
        # Reverse lookup character -> class index. If a character occurs more than once
        # in the mapping, the last index wins.
        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
        self.transform = ImageStem()
        self.input_dims = metadata.DIMS
        self.output_dims = metadata.OUTPUT_DIMS

    def prepare_data(self, *args, **kwargs) -> None:
        """Download and process the raw EMNIST archive unless the processed HDF5 file already exists."""
        if not os.path.exists(PROCESSED_DATA_FILENAME):
            _download_and_process_emnist()

    def setup(self, stage: str = None) -> None:
        """Load the processed HDF5 arrays and wrap them in datasets.

        ``stage="fit"`` loads the training arrays and splits them into ``data_train`` and
        ``data_val`` (``split_dataset`` with ``fraction=TRAIN_FRAC``, ``seed=42``);
        ``stage="test"`` loads ``data_test``; ``None`` does both.
        """
        if stage == "fit" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_trainval = f["x_train"][:]
                # squeeze() drops size-1 axes from the stored labels before casting to int.
                self.y_trainval = f["y_train"][:].squeeze().astype(int)

            data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
            self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

        if stage == "test" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_test = f["x_test"][:]
                self.y_test = f["y_test"][:].squeeze().astype(int)
            self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)

    def __repr__(self):
        """Describe the mapping and dims, plus split sizes and one training batch's stats when available."""
        basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n"
        splits = (self.data_train, self.data_val, self.data_test)
        if all(split is None for split in splits):
            return basic

        # After setup("fit") there is no test split, and after setup("test") there are no
        # train/val splits: report missing splits as None instead of crashing on len(None).
        sizes = ", ".join("None" if split is None else str(len(split)) for split in splits)
        data = f"Train/val/test sizes: {sizes}\n"
        if self.data_train is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
            )
        return basic + data
def _download_and_process_emnist():
    """Download the raw EMNIST archive described by METADATA_FILENAME, then process it into HDF5."""
    # Named `raw_metadata` so it no longer shadows this module's `metadata` import
    # (text_recognizer.metadata.emnist), which the rest of the file relies on.
    raw_metadata = toml.load(METADATA_FILENAME)
    _download_raw_dataset(raw_metadata, DL_DATA_DIRNAME)
    _process_raw_dataset(raw_metadata["filename"], DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path):
    """Extract the ByClass split from the EMNIST zip, save it as compressed HDF5, and write essentials JSON.

    Parameters
    ----------
    filename
        Name of the downloaded zip archive. It is opened inside
        ``temporary_working_directory(dirname)``, so presumably it is relative to ``dirname``.
    dirname
        Directory holding the archive; the whole extraction/processing runs inside it.
    """
    print("Unzipping EMNIST...")
    with temporary_working_directory(dirname):
        with zipfile.ZipFile(filename, "r") as zf:
            zf.extract("matlab/emnist-byclass.mat")

        from scipy.io import loadmat

        # NOTE: If importing at the top of module, would need to list scipy as prod dependency.

        print("Loading training data from .mat file")
        data = loadmat("matlab/emnist-byclass.mat")
        # Flat image rows become 28x28 images; swapaxes(1, 2) transposes each one
        # (presumably the .mat stores pixels column-major -- TODO confirm).
        x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
        y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
        x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
        y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
        # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices

        if SAMPLE_TO_BALANCE:
            print("Balancing classes to reduce amount of data")
            x_train, y_train = _sample_to_balance(x_train, y_train)
            x_test, y_test = _sample_to_balance(x_test, y_test)

        print("Saving to HDF5 in a compressed format...")
        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        # All arrays are stored as unsigned bytes ("u1") with LZF compression.
        with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
            f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
            f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
            f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
            f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")

        print("Saving essential dataset parameters to text_recognizer/data...")
        # The .mat mapping rows appear to be (label, character code) pairs; chr() turns codes into characters.
        mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
        characters = _augment_emnist_characters(list(mapping.values()))
        essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
        with open(ESSENTIALS_FILENAME, "w") as f:
            json.dump(essentials, f)

        # Remove the extracted .mat directory (only reached if everything above succeeded).
        print("Cleaning up...")
        shutil.rmtree("matlab")
def _sample_to_balance(x, y):
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(42)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_inds = []
for label in np.unique(y.flatten()):
inds = np.where(y == label)[0]
sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
all_sampled_inds.append(sampled_inds)
ind = np.concatenate(all_sampled_inds)
x_sampled = x[ind]
y_sampled = y[ind]
return x_sampled, y_sampled
def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset
iam_characters = [
" ",
"!",
'"',
"#",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
":",
";",
"?",
]
# Also add special tokens:
# - CTC blank token at index 0
# - Start token at index 1
# - End token at index 2
# - Padding token at index 3
# NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this!
return ["", " ", *characters, *iam_characters]
if __name__ == "__main__":
load_and_print_info(EMNIST)
================================================
FILE: lab03/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab03/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME
# Default maximum line length in characters; sets the label length and the image width.
DEFAULT_MAX_LENGTH = 32
# Default bounds of the fractional overlap between adjacent characters in a rendered line.
DEFAULT_MIN_OVERLAP = 0
DEFAULT_MAX_OVERLAP = 0.33
# Default number of synthetic lines generated per split.
NUM_TRAIN = 10000
NUM_VAL = 2000
NUM_TEST = 2000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.

    Each example is an image built by pasting EMNIST character samples side by side
    (with random overlap) for a generated sentence, paired with that sentence encoded
    as a fixed-length label vector.
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)
        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)

        self.mapping = metadata.MAPPING
        self.output_dims = (self.max_length, 1)
        # The canvas is wide enough for max_length non-overlapping characters.
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)

        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """HDF5 cache path; the name encodes this instance's generation parameters."""
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate and cache the train/val/test splits unless the HDF5 file already exists."""
        if self.data_filename.exists():
            return
        # Seed once, then generate the splits in a fixed order so NumPy-driven sampling is reproducible.
        np.random.seed(42)
        self._generate_data("train")
        self._generate_data("val")
        self._generate_data("test")

    def setup(self, stage: str = None) -> None:
        """Load cached splits: "fit" loads train/val, "test" loads test, None loads all three."""
        print("EMNISTLinesDataset loading data from HDF5...")
        if stage == "fit" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)

            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)

        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset."""
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        splits = (self.data_train, self.data_val, self.data_test)
        if all(split is None for split in splits):
            return basic

        # After setup("fit") there is no test split, and after setup("test") there are no
        # train/val splits: report missing splits as None instead of crashing on len(None).
        sizes = ", ".join("None" if split is None else str(len(split)) for split in splits)
        data = f"Train/val/test sizes: {sizes}\n"
        if self.data_train is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
            )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Render the requested number of lines for one split and append them to the HDF5 cache.

        Writes datasets ``x_<split>`` (images) and ``y_<split>`` (label vectors), both stored as "u1".
        """
        print(f"EMNISTLinesDataset generating data for {split}...")

        from text_recognizer.data.sentence_generator import SentenceGenerator

        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract two because we will add start/end tokens

        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()

        # NOTE(review): "train" and "val" both draw character images from the same EMNIST
        # train+val arrays, so validation lines can reuse training characters -- confirm intended.
        if split == "train":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train
        elif split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test

        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        # Append mode: each split adds its own datasets to the same file.
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Group samples by the character their label maps to.

    Returns a ``defaultdict(list)`` from ``mapping[label]`` to the samples carrying that
    label, in input order. Characters with no samples read as an empty list.
    """
    grouped = defaultdict(list)
    for lab, img in zip(labels, samples):
        grouped[mapping[lab]].append(img)
    return grouped
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Pick one image per distinct character of ``string`` and return them in string order.

    Repeated characters reuse the same sampled image. A character with no samples gets an
    all-zero uint8 image of ``char_shape``. One draw from NumPy's global random state is made
    for each distinct character that has samples.
    """
    blank = torch.zeros(char_shape, dtype=torch.uint8)
    chosen = {}
    for ch in string:
        if ch in chosen:
            continue
        candidates = samples_by_char[ch]
        if candidates:
            picked = candidates[np.random.choice(len(candidates))]
        else:
            picked = blank
        chosen[ch] = picked.reshape(*char_shape)
    return [chosen[ch] for ch in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render ``string`` as one line image by pasting per-character samples left to right.

    One overlap fraction is drawn with ``np.random.uniform(min_overlap, max_overlap)`` per
    string. Consecutive characters start ``W - int(overlap * W)`` pixels apart, so overlapping
    strokes add together.

    Returns
    -------
    torch.Tensor
        (H, width) float32 image with pixel values saturated at 255.
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in a wide integer type. Summing overlapping pixels in uint8 wraps around
    # modulo 256, so the saturation at 255 below would never take effect.
    concatenated_image = torch.zeros((H, width), dtype=torch.int32)
    x = 0
    for image in sampled_images:
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image, dtype=torch.int32)
        x += next_overlap_width
    return torch.clamp(concatenated_image, max=255).to(torch.float32)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate ``N`` sentences and render each one as a line image.

    Parameters
    ----------
    N
        Number of examples to create.
    samples_by_char
        Mapping from character to candidate sample images.
    sentence_generator
        Object whose ``generate()`` method returns the next sentence string.
    min_overlap, max_overlap
        Bounds for the per-line character overlap fraction.
    dims
        Image dimensions: ``dims[1]`` is the height, ``dims[2]`` the width (``dims[-1]`` is passed as the render width).

    Returns
    -------
    tuple
        ``(images, labels)``: a float tensor of shape (N, dims[1], dims[2]) and the list of sentences.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    sentences = []
    for row in range(N):
        # Generate and render one example at a time so random draws interleave as before.
        text = sentence_generator.generate()
        images[row] = construct_image_from_string(text, samples_by_char, min_overlap, max_overlap, dims[-1])
        sentences.append(text)
    return images, sentences
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" "]
self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
self.val_cer = CharacterErrorRate(self.ignore_tokens)
self.test_cer = CharacterErrorRate(self.ignore_tokens)
================================================
FILE: lab03/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence
import torch
import torchmetrics
class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored.

    Predictions and targets are integer token tensors. Tokens listed in ``ignore_tokens``
    are removed from every row of both before the parent metric's edit-distance update.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        # Keyword options are forwarded to the torchmetrics base class; previously there
        # was no way to pass them through this wrapper.
        super().__init__(*args, **kwargs)
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        """Drop ignored tokens from each row of ``preds`` and ``targets``, then accumulate statistics."""
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)
def test_character_error_rate():
    """Check that tokens 0 and 1 are ignored before the character error rate is computed."""
    metric = CharacterErrorRate([0, 1])
    target_row = [0, 2, 2, 3, 3, 1]
    # (prediction row, expected error for that row once tokens 0 and 1 are dropped)
    cases = [
        ([0, 2, 2, 3, 3, 1], 0),
        ([0, 2, 1, 1, 1, 1], 0.75),
        ([0, 2, 2, 4, 4, 1], 0.5),
    ]
    preds = torch.tensor([row for row, _ in cases])
    targets = torch.tensor([target_row for _ in cases])

    metric(preds, targets)
    expected = sum([err for _, err in cases]) / 3
    assert metric.compute() == expected


if __name__ == "__main__":
    test_character_error_rate()
================================================
FILE: lab03/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence
import torch
from .base import BaseImageToTextLitModel
from .util import replace_after
class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.

    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """

    def __init__(self, model, args=None):
        super().__init__(model, args)
        # Target positions holding the padding token do not contribute to the loss.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)

    def forward(self, x):
        """Run the wrapped model's forward pass, i.e. production inference."""
        return self.model(x)

    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.

        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.

        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def training_step(self, batch, batch_idx):
        """Teacher-forced step: predict y[:, 1:] from the images and y[:, :-1]."""
        x, y = batch
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("train/loss", loss)

        outputs = {"loss": loss}

        return outputs

    def validation_step(self, batch, batch_idx):
        """Log teacher-forced loss and the CER of production-style predictions on a validation batch."""
        x, y = batch

        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("validation/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)

        self.val_cer(preds, y)
        self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True)

        return outputs

    def test_step(self, batch, batch_idx):
        """Log teacher-forced loss and the CER of production-style predictions on a test batch."""
        x, y = batch

        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("test/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)

        # Use the dedicated test metric: updating val_cer here would mix test batches
        # into the validation metric's accumulated state.
        self.test_cer(preds, y)
        self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True)

        return outputs

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping."""
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        else:
            return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping."""
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.

        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.

        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        else:
            return raw  # (B, Sy)
================================================
FILE: lab03/text_recognizer/lit_models/util.py
================================================
from typing import Union
import torch
def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Return indices of first appearance of element in x, collapsing along dim.

    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9

    Parameters
    ----------
    x
        One or two-dimensional Tensor to search for element.
    element
        Item to search for inside x.
    dim
        Dimension of Tensor to collapse over.

    Returns
    -------
    torch.Tensor
        Indices (int64) where element first occurs in x. Where element is not found,
        the length of x along dim is returned. One dimension smaller than x.

    Raises
    ------
    ValueError
        if x is not a 1 or 2 dimensional Tensor

    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    if x.dim() not in (1, 2):
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}")
    hits = x == element
    found = hits.any(dim)
    # argmax returns the index of the first maximal value, i.e. the first hit when one exists.
    first_hit = hits.int().argmax(dim)
    return torch.where(found, first_hit, x.shape[dim])
def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Replace all values in each row of 2d Tensor x after the first appearance of element with replace.

    The first occurrence itself is kept; rows without element are returned unchanged.

    Parameters
    ----------
    x
        Two-dimensional Tensor (shape denoted (B, S)) to replace values in.
    element
        Item to search for inside x.
    replace
        Item that replaces entries that appear after element.

    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    cutoff = first_appearance(x, element, dim=1)  # (B,), equals S for rows without element
    positions = torch.arange(0, x.shape[-1]).type_as(x)  # (S,)
    keep = positions.unsqueeze(0) <= cutoff.unsqueeze(1)  # (B, S): True up to and including the match
    return torch.where(keep, x, replace)
================================================
FILE: lab03/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path
import text_recognizer.metadata.shared as shared
# Where the EMNIST download description, the downloaded archive and the processed HDF5 live.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
# parents[1] of this file is the text_recognizer package, so this points into text_recognizer/data/.
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"

# Number of class indices reserved ahead of the EMNIST characters.
NUM_SPECIAL_TOKENS = 4

INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)
# Class index -> character: two leading special entries, the 62 EMNIST ByClass characters
# (digits, uppercase, lowercase), then extra punctuation from the IAM dataset.
# NOTE(review): " " appears twice (indices 1 and 64), and only two special entries precede
# the digits although NUM_SPECIAL_TOKENS is 4 -- confirm this list is intended.
MAPPING = [
    "",
    " ",
    *"0123456789",
    *"ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    *"abcdefghijklmnopqrstuvwxyz",
    *" !\"#&'()*+,-./:;?",
]
================================================
FILE: lab03/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
# parents[1] of this file is the text_recognizer package, so this points into text_recognizer/data/.
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"

# Size of one EMNIST character image, taken from the (channels, height, width) emnist.DIMS.
CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)  # width variable, depends on maximum sequence length

# Lines reuse the EMNIST character mapping unchanged.
MAPPING = emnist.MAPPING
================================================
FILE: lab03/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared
# Locations of the IAM dataset's download description, downloaded files and extracted archive.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"

DOWNSAMPLE_FACTOR = 2  # if images were downsampled, the regions must also be
LINE_REGION_PADDING = 8  # add this many pixels around the exact coordinates
================================================
FILE: lab03/text_recognizer/metadata/iam_paragraphs.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs"

# The EMNIST mapping extended with a newline class, appended after all EMNIST entries.
NEW_LINE_TOKEN = "\n"
MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN]

# NOTE(review): presumably the factor paragraph images are rescaled by -- confirm in the data module.
IMAGE_SCALE_FACTOR = 2
IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH)

# Presumably the maximum label length in tokens; it is also the output sequence length below.
MAX_LABEL_LENGTH = 682

DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)  # (channels, height, width)
OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1)
================================================
FILE: lab03/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME

DIMS = (1, 28, 28)  # (channels, height, width) of one input image
OUTPUT_DIMS = (1,)
MAPPING = list(range(10))  # class index i -> digit i

# NOTE(review): presumably the sizes of the train/validation split of the MNIST training set -- confirm.
TRAIN_SIZE = 55000
VAL_SIZE = 5000
================================================
FILE: lab03/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# Shared data root: parents[3] of this file's resolved path, i.e. for
# lab03/text_recognizer/metadata/shared.py the directory that contains lab03.
DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded"
================================================
FILE: lab03/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
from .resnet_transformer import ResnetTransformer
================================================
FILE: lab03/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
CONV_DIM = 64
FC_DIM = 128
FC_DROPOUT = 0.25


class ConvBlock(nn.Module):
    """
    Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU.
    """

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, C_in, H, W) tensor to a non-negative (B, C_out, H, W) tensor."""
        return self.relu(self.conv(x))


class CNN(nn.Module):
    """Simple CNN for recognizing characters in a square image.

    Two ConvBlocks, a 2x2 max-pool, dropout, then two linear layers producing one logit
    per class in ``data_config["mapping"]``.
    """

    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_channels, input_height, input_width = self.data_config["input_dims"]
        assert (
            input_height == input_width
        ), f"input height and width should be equal, but was {input_height}, {input_width}"
        self.input_height, self.input_width = input_height, input_width

        num_classes = len(self.data_config["mapping"])
        options = self.args
        conv_dim = options.get("conv_dim", CONV_DIM)
        fc_dim = options.get("fc_dim", FC_DIM)
        fc_dropout = options.get("fc_dropout", FC_DROPOUT)

        # Layers are created in the original order so seeded weight initialisation is unchanged.
        self.conv1 = ConvBlock(input_channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)

        # The padded 3x3 convs keep H and W; only the 2x2 max-pool halves them (floor division).
        pooled_height = input_height // 2
        pooled_width = input_width // 2
        self.fc_input_dim = int(pooled_height * pooled_width * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the CNN to x.

        Parameters
        ----------
        x
            (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config.

        Returns
        -------
        torch.Tensor
            (B, Cl) tensor
        """
        _B, _Ch, H, W = x.shape
        assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}"
        features = self.conv2(self.conv1(x))  # (B, conv_dim, H, W)
        features = self.dropout(self.max_pool(features))  # (B, conv_dim, H // 2, W // 2)
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))  # (B, fc_dim)
        return self.fc2(hidden)  # (B, Cl)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab03/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
IMAGE_SIZE = 28
WINDOW_WIDTH = IMAGE_SIZE
WINDOW_STRIDE = IMAGE_SIZE


class LineCNNSimple(nn.Module):
    """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.

    A square window of width ``WW`` slides across the image with stride ``WS``, and a shared
    ``CNN`` classifies each window independently.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        options = self.args
        self.WW = options.get("window_width", WINDOW_WIDTH)
        self.WS = options.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = options.get("limit_output_length", False)

        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        # The per-window CNN sees square (channels, WW, WW) inputs.
        window_config = dict(data_config)
        window_config["input_dims"] = (data_config["input_dims"][0], self.WW, self.WW)
        self.cnn = CNN(data_config=window_config, args=args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the LineCNN to an input image and return logits.

        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to IMAGE_SIZE

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W and CHAR_WIDTH
            C is self.num_classes
        """
        B, _C, H, W = x.shape
        assert H == IMAGE_SIZE  # Make sure we can use our CNN class

        num_windows = math.floor((W - self.WW) / self.WS + 1)
        # NOTE: type_as properly sets device
        activations = torch.zeros((B, self.num_classes, num_windows)).type_as(x)
        for s in range(num_windows):
            left = s * self.WS
            activations[:, :, s] = self.cnn(x[:, :, :, left : left + self.WW])

        if self.limit_output_length:
            # S might not match ground truth, so let's only take enough activations as are expected
            activations = activations[:, :, : self.output_length]
        return activations

    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab03/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
FC1_DIM = 1024
FC2_DIM = 128
FC_DROPOUT = 0.5


class MLP(nn.Module):
    """Simple MLP suitable for recognizing single characters.

    Flattens the input, applies two ReLU hidden layers (each followed by dropout) and a
    final linear layer with one logit per class in ``data_config["mapping"]``.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        in_features = np.prod(self.data_config["input_dims"])
        n_classes = len(self.data_config["mapping"])
        hidden1 = self.args.get("fc1", FC1_DIM)
        hidden2 = self.args.get("fc2", FC2_DIM)
        p_drop = self.args.get("fc_dropout", FC_DROPOUT)

        # Creation order (fc1, dropout, fc2, fc3) is kept so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.dropout = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, n_classes)

    def forward(self, x):
        """Map a batch of inputs to (B, num_classes) logits."""
        hidden = torch.flatten(x, 1)
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc3(hidden)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--fc1", type=int, default=FC1_DIM)
        parser.add_argument("--fc2", type=int, default=FC2_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab03/text_recognizer/models/resnet_transformer.py
================================================
"""Model combining a ResNet with a Transformer for image-to-sequence tasks."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
import torchvision
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage
TF_DIM = 256
TF_FC_DIM = 1024
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4
RESNET_DIM = 512  # hard-coded


class ResnetTransformer(nn.Module):
    """Pass an image through a Resnet and decode the resulting embedding with a Transformer.

    Encoder: ResNet-18 without its final pooling/linear layers, a 1x1 convolution projecting
    to ``tf_dim`` channels, and 2-D positional encodings; the feature map is flattened into a
    sequence. Decoder: token embeddings plus 1-D positional encodings fed to a
    ``nn.TransformerDecoder``, followed by a linear readout over ``data_config["mapping"]``.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        self.mapping = data_config["mapping"]
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # forward() and decode() rely on all three special tokens below.
        # NOTE(review): assumes the mapping contains the "<S>" (start), "<E>" (end) and
        # "<P>" (padding) tokens used by the text_recognizer data modules -- confirm against
        # text_recognizer.metadata.
        self.start_token = inverse_mapping["<S>"]
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # ## Encoder part - should output vector sequence of length self.dim per sample
        resnet = torchvision.models.resnet18(weights=None)
        self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2]))  # Exclude AvgPool and Linear layers
        # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32

        self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
        # encoder_projection will output (B, dim, _H, _W) logits

        self.enc_pos_encoder = PositionalEncodingImage(
            d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
        )  # Max (Ho, Wo)

        # ## Decoder part
        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)

        # Causal mask, sliced to the current target length in decode().
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce sequences of labels from input images.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        output_tokens
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        # Start from all-padding, with the start token in position 0; type_as moves it to x's device.
        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, Sy)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    def init_weights(self):
        """Initialize the embedding, readout, and encoder-projection parameters."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

        nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
        if self.encoder_projection.bias is not None:
            _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # Symmetric range [-bound, bound]: nn.init.normal_ would read these as (mean, std).
            nn.init.uniform_(self.encoder_projection.bias, -bound, bound)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
            (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps
        """
        _B, C, _H, _W = x.shape
        if C == 1:
            # ResNet expects 3 input channels; replicate the grayscale channel.
            x = x.repeat(1, 3, 1, 1)
        x = self.resnet(x)  # (B, RESNET_DIM, _H // 32, _W // 32),  (B, 512, 18, 20) in the case of IAMParagraphs
        x = self.encoder_projection(x)  # (B, E, _H // 32, _W // 32),   (B, 256, 18, 20) in the case of IAMParagraphs

        # x = x * math.sqrt(self.dim)  # (B, E, _H // 32, _W // 32)  # This prevented any learning
        x = self.enc_pos_encoder(x)  # (B, E, Ho, Wo);     Ho = _H // 32, Wo = _W // 32
        x = torch.flatten(x, start_dim=2)  # (B, E, Ho * Wo)
        x = x.permute(2, 0, 1)  # (Sx, B, E);    Sx = Ho * Wo
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x with guiding sequences y.

        During autoregressive inference, the guiding sequence will be previous predictions.

        During training, the guiding sequence will be the ground truth.

        Parameters
        ----------
        x
            (Sx, B, E) images encoded as sequences of embeddings
        y
            (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) batch of logit sequences
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.dec_pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    @staticmethod
    def add_to_argparse(parser):
        """Register the Transformer hyperparameter flags on ``parser`` and return it."""
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        # Default matches the TF_FC_DIM used when no args are passed. Checkpoints trained with the
        # previous CLI default (TF_DIM) need an explicit --tf_fc_dim to reload.
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab03/text_recognizer/models/transformer_util.py
================================================
"""Position Encoding and other utilities for Transformers."""
import math
import torch
from torch import Tensor
import torch.nn as nn
class PositionalEncodingImage(nn.Module):
    """
    Module used to add 2-D positional encodings to the feature-map produced by the encoder.

    The first half of the channels encodes the row index and the second half the column
    index, each with 1-D sinusoidal encodings from ``PositionalEncoding.make_pe``.
    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        table = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w)  # (d_model, max_h, max_w)
        # Deterministic, so by default it is kept out of the state_dict and rebuilt on construction.
        self.register_buffer("pe", table, persistent=persistent)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
        """Build a (d_model, max_h, max_w) table: row encodings stacked on column encodings."""
        half = d_model // 2
        rows = PositionalEncoding.make_pe(d_model=half, max_len=max_h)  # (max_h, 1, half)
        rows = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (half, max_h, max_w)
        cols = PositionalEncoding.make_pe(d_model=half, max_len=max_w)  # (max_w, 1, half)
        cols = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (half, max_h, max_w)
        return torch.cat([rows, cols], dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encoding table, cropped to x's spatial size, to x of shape (B, d_model, H, W)."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        return x + self.pe[:, :height, :width]  # type: ignore
class PositionalEncoding(torch.nn.Module):
    """Classic Attention-is-all-you-need positional encoding.

    Adds fixed sinusoidal encodings to a (S, B, d_model) sequence, then applies dropout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = self.make_pe(d_model=d_model, max_len=max_len)  # (max_len, 1, d_model)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> torch.Tensor:
        """Build a (max_len, 1, d_model) table of sinusoidal position encodings.

        Even feature indices hold sines and odd indices hold cosines of the same frequencies.
        An odd ``d_model`` is supported: the last column is a sine without a cosine partner.
        This matters for ``PositionalEncodingImage``, which requests ``d_model // 2`` features.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d_model / 2) frequencies but only floor(d_model / 2) odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first x.size(0) positions to x of shape (S, B, d_model)."""
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Return a (size, size) float32 causal attention mask.

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when j > i.
    """
    # Fill with -inf, then zero out the diagonal and everything below it.
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)
================================================
FILE: lab03/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """A stem for models operating on images.

    Images are presumed to be provided as PIL images, as is standard for torchvision Datasets.
    A call runs three stages in order:

    1. ``pil_transforms``: PIL image -> PIL image (identity by default);
    2. ``pil_to_tensor``: PIL image -> torch Tensor;
    3. ``torch_transforms``: Tensor -> Tensor (identity by default), run under ``torch.no_grad()``.

    ``torch_transforms`` is a ``torch.nn.Sequential``, so it is compatible with TorchScript
    whenever the underlying Modules are.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()

    def __call__(self, img):
        tensor = self.pil_to_tensor(self.pil_transforms(img))
        with torch.no_grad():
            return self.torch_transforms(tensor)
class MNISTStem(ImageStem):
    """A stem for handling images from the MNIST dataset.

    After tensor conversion, inputs are normalized with a fixed mean (0.1307) and std (0.3081).
    """

    def __init__(self):
        super().__init__()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.torch_transforms = torch.nn.Sequential(normalize)
================================================
FILE: lab03/text_recognizer/stems/paragraph.py
================================================
"""IAMParagraphs Stem class."""
import torchvision.transforms as transforms
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.stems.image import ImageStem
IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE

MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH


class ParagraphStem(ImageStem):
    """A stem for handling images that contain a paragraph of text.

    Without augmentation, images are center-cropped to ``IMAGE_SHAPE``. With ``augment=True``
    a randomized PIL pipeline is used instead: color jitter, random crop (zero-padded when the
    image is too small), random affine, random perspective, Gaussian blur, and random
    sharpness adjustment. Each ``*_kwargs`` argument, when given, replaces (does not merge with)
    the default keyword arguments for its transform.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()

        if not augment:
            self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)])
            return

        bilinear = transforms.InterpolationMode.BILINEAR
        jitter_kwargs = (
            {"brightness": 0.4, "contrast": 0.4} if color_jitter_kwargs is None else color_jitter_kwargs
        )
        affine_kwargs = (
            {"degrees": 3, "shear": 6, "scale": (0.95, 1), "interpolation": bilinear}
            if random_affine_kwargs is None
            else random_affine_kwargs
        )
        perspective_kwargs = (
            {"distortion_scale": 0.2, "p": 0.5, "interpolation": bilinear}
            if random_perspective_kwargs is None
            else random_perspective_kwargs
        )
        blur_kwargs = (
            {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if gaussian_blur_kwargs is None else gaussian_blur_kwargs
        )
        sharpen_kwargs = {"sharpness_factor": 2, "p": 0.5} if sharpness_kwargs is None else sharpness_kwargs

        # IMAGE_SHAPE comes from text_recognizer.metadata.iam_paragraphs.
        self.pil_transforms = transforms.Compose(
            [
                transforms.ColorJitter(**jitter_kwargs),
                transforms.RandomCrop(
                    size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant"
                ),
                transforms.RandomAffine(**affine_kwargs),
                transforms.RandomPerspective(**perspective_kwargs),
                transforms.GaussianBlur(**blur_kwargs),
                transforms.RandomAdjustSharpness(**sharpen_kwargs),
            ]
        )
================================================
FILE: lab03/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """1-hot encode integer labels ``y`` as uint8, adding a trailing axis of length ``num_classes``."""
    identity = np.eye(num_classes, dtype="uint8")
    return identity[y]
def read_image_pil(image_uri: Union[Path, str], grayscale: bool = False) -> Image.Image:
    """Open ``image_uri`` with smart_open and return its contents as a PIL image.

    ``image_uri`` may be anything ``smart_open.open`` accepts (local paths as well as
    the remote URIs smart_open supports). If ``grayscale`` is True, the image is
    converted to mode "L".
    """
    # Return type is the Image.Image class; plain ``Image`` is the PIL.Image module.
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale: bool = False) -> Image.Image:
    """Read a PIL image from an open binary file object.

    If ``grayscale`` is True the result is converted to mode "L"; otherwise it keeps the
    file's own mode. ``Image.convert`` returns a converted copy, so the returned image
    does not depend on the handle that the ``with`` block closes.
    """
    # Return type is the Image.Image class; plain ``Image`` is the PIL.Image module.
    with Image.open(image_file) as image:
        if grayscale:
            image = image.convert(mode="L")
        else:
            image = image.convert(mode=image.mode)
        return image
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Run the ``with`` body inside ``working_dir``, then return to the previous directory.

    The previous working directory is restored even if the body raises.
    """
    previous_dir = os.getcwd()
    with contextlib.ExitStack() as stack:
        os.chdir(working_dir)
        # Registered only after a successful chdir, mirroring try/finally placement.
        stack.callback(os.chdir, previous_dir)
        yield
def compute_sha256(filename: Union[Path, str], chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 checksum of a file.

    The file is hashed in ``chunk_size``-byte pieces (default 1 MiB), so large downloaded
    archives are never read into memory all at once.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive (a zero-byte read would end the loop immediately).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """tqdm progress bar whose ``update_to`` can serve as a ``urlretrieve`` ``reporthook``.

    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """Move the bar to ``blocks * bsize`` units transferred.

        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # update() adds a delta to self.n, so passing the difference sets n = transferred.
        self.update(transferred - self.n)
def download_url(url, filename):
    """Download ``url`` to ``filename`` while showing a byte-level progress bar."""
    bar_kwargs = dict(unit="B", unit_scale=True, unit_divisor=1024, miniters=1)
    with TqdmUpTo(**bar_kwargs) as progress:
        urlretrieve(url, filename, reporthook=progress.update_to, data=None)  # noqa: S310
================================================
FILE: lab03/training/__init__.py
================================================
================================================
FILE: lab03/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# In order to ensure reproducible experiments, we must set random seeds.
# These run at import time and seed NumPy's global RNG and PyTorch's RNG with a fixed value.
np.random.seed(42)
torch.manual_seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments.

    Parsing is two-pass. The Trainer and basic arguments are parsed first with
    ``parse_known_args``, so ``--data_class`` and ``--model_class`` can be resolved. Only
    then are the chosen classes' own arguments added. The parsers are built with
    ``add_help=False`` and ``--help`` is registered last, so ``--help`` lists those
    class-specific arguments too.
    """
    parser = argparse.ArgumentParser(add_help=False)
    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    # NOTE(review): retitles a group via argparse's private ``_action_groups`` list (index 1)
    # for --help display; fragile across argparse / Lightning versions.
    trainer_parser._action_groups[1].title = "Trainer Args"
    # Fresh parser inheriting every Trainer argument; train for a single epoch unless overridden.
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    parser.set_defaults(max_epochs=1)
    # Basic arguments
    parser.add_argument(
        "--data_class",
        type=str,
        default="MNIST",
        help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--model_class",
        type=str,
        default="MLP",
        help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path."
    )
    parser.add_argument(
        "--stop_early",
        type=int,
        default=0,
        help="If non-zero, applies early stopping, with the provided value as the 'patience' argument."
        + " Default is 0.",
    )
    # Get the data and model classes, so that we can add their specific arguments
    # (parse_known_args ignores flags that only those classes define).
    temp_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}")
    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)
    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)
    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_models.BaseLitModel.add_to_argparse(lit_model_group)
    # Registered last so --help shows every group added above.
    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create ``experiment_dir`` (and any missing parents); only the rank-zero process runs this."""
    target = Path(experiment_dir)
    target.mkdir(exist_ok=True, parents=True)
def main():
    """
    Run an experiment.
    Builds the data module and model from the CLI arguments, wraps the model in a
    LightningModule (TransformerLitModel when --loss=transformer, BaseLitModel otherwise),
    trains with TensorBoard logging under training/logs and top-5 checkpointing, then
    tests the best checkpoint (or the in-memory model if no checkpoint was saved).
    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```
    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```
    The available command line args differ depending on some of the arguments, including --model_class and --data_class.
    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```
    """
    parser = _setup_parser()
    args = parser.parse_args()
    data, model = setup_data_and_model_from_args(args)
    lit_model_class = lit_models.BaseLitModel
    if args.loss == "transformer":
        lit_model_class = lit_models.TransformerLitModel
    if args.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)
    else:
        lit_model = lit_model_class(args=args, model=model)
    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    logger = pl.loggers.TensorBoardLogger(log_dir)
    experiment_dir = logger.log_dir
    # Checkpoints are ranked by character error rate for the transformer loss, else by validation loss.
    goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss"
    filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
    if goldstar_metric == "validation/cer":
        filename_format += "-validation.cer={validation/cer:.3f}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename=filename_format,
        monitor=goldstar_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )
    summary_callback = pl.callbacks.ModelSummary(max_depth=2)
    callbacks = [summary_callback, checkpoint_callback]
    if args.stop_early:
        # NOTE: early stopping always monitors validation/loss, even when checkpoints rank by validation/cer.
        early_stopping_callback = pl.callbacks.EarlyStopping(
            monitor="validation/loss", mode="min", patience=args.stop_early
        )
        callbacks.append(early_stopping_callback)
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)
    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate
    trainer.fit(lit_model, datamodule=data)
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        rank_zero_info(f"Best model saved at: {best_model_path}")
        trainer.test(datamodule=data, ckpt_path=best_model_path)
    else:
        trainer.test(lit_model, datamodule=data)
if __name__ == "__main__":
    main()
================================================
FILE: lab03/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
DATA_CLASS_MODULE = "text_recognizer.data"
MODEL_CLASS_MODULE = "text_recognizer.models"


def import_class(module_and_class_name: str) -> type:
    """Import and return a class from its dotted path, e.g. 'text_recognizer.models.MLP'.

    Everything before the last dot is imported as a module, and the final component is
    looked up as an attribute of it. A name without any dot raises ValueError.
    """
    module_name, class_name = module_and_class_name.rsplit(".", maxsplit=1)
    return getattr(importlib.import_module(module_name), class_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Instantiate the data module and model named by ``args.data_class`` and ``args.model_class``.

    Both classes are resolved before anything is constructed. The model is built from the
    data module's ``config()``. Returns ``(data, model)``.
    """
    data_cls, model_cls = (
        import_class(f"{DATA_CLASS_MODULE}.{args.data_class}"),
        import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}"),
    )
    data = data_cls(args)
    return data, model_cls(data_config=data.config(), args=args)
================================================
FILE: lab04/notebooks/lab01_pytorch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"A model that always predicts the padding token\n",
"can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab04/notebooks/lab03_transformers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" \", \"\")\n",
"\n",
"idx = random.randint(0, len(xs))\n",
"\n",
"print(show(ys[idx]))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4dT3UCNzTsoc"
},
"source": [
"The `ResnetTransformer` model can run on this data\n",
"if passed the `.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXL-vIGRr86D"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MMxa-oWyT01E"
},
"source": [
"Our models are now big enough\n",
"that we want to make use of GPU acceleration\n",
"as much as we can,\n",
"even when working on single inputs,\n",
"so let's cast to the GPU if we have one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-YyUM8LgvW0w"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"rnt.to(device); xs = xs.to(device); ys = ys.to(device);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-E3UdD4zUJi"
},
"source": [
"First, let's just pass it through the ResNet encoder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-LUUtlvaxrvg"
},
"outputs": [],
"source": [
"resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n",
" # resnet is designed for RGB images, so we replicate the input across channels 3 times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eimgJ5dnywjg"
},
"outputs": [],
"source": [
"resnet_idx = random.randint(0, len(resnet_embedding)) # re-execute to view a different channel\n",
"plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n",
"plt.axis(\"off\"); plt.colorbar(fraction=0.05);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These embeddings, though generated by random, untrained weights,\n",
"are not entirely useless.\n",
"\n",
"Before neural networks could be effectively\n",
"trained end to end,\n",
"they were often used with frozen random weights\n",
"everywhere except the final layer\n",
"(see e.g.\n",
"[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n",
"[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n",
"these methods were still competitive, and\n",
"[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n",
"provide a\n",
"[theoretical basis](https://arxiv.org/abs/2011.14522)\n",
"for understanding their performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ye6pW0ETzw2A"
},
"source": [
"The final result, though, is repetitive gibberish --\n",
"at the bare minimum, we need to train the unembedding/readout layer\n",
"in order to get reasonable text."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our architecture includes randomization with dropout,\n",
"so repeated runs of the cell below will generate different outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu3Pa7gLsFMo"
},
"outputs": [],
"source": [
"preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gvCXUbskv6XM"
},
"outputs": [],
"source": [
"print(show(preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without teacher forcing, runtime is also variable from iteration to iteration --\n",
"the model stops when it generates an \"end sequence\" or padding token,\n",
"which is not deterministic thanks to the dropout layers.\n",
"For similar reasons, runtime is variable across inputs.\n",
"\n",
"The variable runtime of autoregressive generation\n",
"is also not great for scaling.\n",
"In a distributed setting, as required for large scale,\n",
"forward passes need to be synced across devices,\n",
"and if one device is generating a batch of much longer sequences,\n",
"it will cause all the others to idle while they wait on it to finish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t76MSVRXV0V7"
},
"source": [
"Let's turn our model into a `TransformerLitModel`\n",
"so we can run with teacher forcing.\n",
"\n",
"> You may be wondering:\n",
" why isn't teacher forcing part of the PyTorch module?\n",
" In general, the `LightningModule`\n",
" should encapsulate things that are needed in training, validation, and testing\n",
" but not during inference.\n",
" The teacher forcing trick fits this paradigm,\n",
" even though it's so critical to what makes Transformers powerful. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qrHRKHowdDi"
},
"outputs": [],
"source": [
"import text_recognizer.lit_models\n",
"\n",
"lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlNaFqR50Oid"
},
"source": [
"Now we can use `.teacher_forward` if we also provide the target `ys`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lpZdqXS5wn0F"
},
"outputs": [],
"source": [
"forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Zx9SmsN0QLT"
},
"source": [
"This may not run faster than the `rnt.forward`,\n",
"since generations are always the maximum possible length,\n",
"but runtimes and output lengths are deterministic and constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tu-XNYpi0Qvi"
},
"source": [
"Forcing doesn't necessarily make our predictions better.\n",
"They remain highly repetitive gibberish."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JcEgify9w0sv"
},
"outputs": [],
"source": [
"forcing_preds = torch.argmax(forcing_outs, dim=0)\n",
"\n",
"print(show(forcing_preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xn6GGNzc9a3o"
},
"source": [
"## Training the `ResNetTransformer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uvZYsuSyWUXe"
},
"source": [
"We're finally ready to train this model on full paragraphs of handwritten text!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cJwC7b720Sd"
},
"source": [
"This is a more serious model --\n",
"it's the one we use in the\n",
"[deployed TextRecognizer application](http://fsdl.me/app).\n",
"It's much larger than the models we've seen thus far,\n",
"so it can easily outstrip available compute resources,\n",
"in particular GPU memory.\n",
"\n",
"To help, we use\n",
"[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n",
"which shrinks the size of most of our floats by half,\n",
"which reduces memory consumption and can speed up computation.\n",
"\n",
"If your GPU has less than 8GB of available RAM,\n",
"you'll see a \"CUDA out of memory\" `RuntimeError`,\n",
"which is something of a\n",
"[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n",
"In this case, you can resolve it by reducing the `--batch_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w1mXlhfy04Nm"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gpus = int(torch.cuda.is_available())\n",
"\n",
"if gpus:\n",
" !nvidia-smi\n",
"else:\n",
" print(\"watch out! working with this model on a typical CPU is not feasible\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "os1vW1rPZ1dy"
},
"source": [
"Even with an okay GPU, like a\n",
"[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n",
"a single epoch of training can take over 10 minutes to run.\n",
"We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n",
"but you can remove those flags to see what full training looks like."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnF6dWFn4JlZ"
},
"source": [
"It can take a long time (overnight)\n",
"to train this model to decent performance on a single GPU,\n",
"so we'll focus on other pieces for the exercises.\n",
"\n",
"> At the time of writing in mid-2022, the cheapest readily available option\n",
"for training this model to decent performance on this dataset with this codebase\n",
"comes out around $10, using\n",
"[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n",
"See, for example,\n",
"[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n",
"and associated experiment.\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HufjdUZN0t4l",
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"# above %%magic times the cell, useful as a poor man's profiler\n",
"\n",
"%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n",
" --gpus={gpus} --batch_size 16 --precision 16 \\\n",
" --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L6fQ93ju3Iku"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udb1Ekjx3L63"
},
"source": [
"### 🌟 Try out gradient accumulation and other \"training tricks\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kpqViB4p3Wfb"
},
"source": [
"Larger batches are helpful not only for increasing parallelization\n",
"and amortizing fixed costs\n",
"but also for getting more reliable gradients.\n",
"Larger batches give gradients with less noise\n",
"and, up to a point, less gradient noise means faster convergence.\n",
"\n",
"But larger batches result in larger tensors,\n",
"which take up more GPU memory,\n",
"a resource that is tightly constrained\n",
"and device-dependent.\n",
"\n",
"Does that mean we are limited in the quality of our gradients\n",
"due to our machine size?\n",
"\n",
"Not entirely:\n",
"look up the `--accumulate_grad_batches`\n",
"argument to the `pl.Trainer`.\n",
"You should be able to understand why\n",
"it makes it possible to compute the same gradients\n",
"you would find for a batch of size `k * N`\n",
"on a machine that can only run batches up to size `N`.\n",
"\n",
"Accumulating gradients across batches is among the\n",
"[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n",
"Try some of them out!\n",
"Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2vtkmX830y3"
},
"source": [
"### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n",
"\n",
"While training this model to actually fit the whole dataset is infeasible\n",
"as a short exercise on commodity hardware,\n",
"it's practical to train this model to memorize a batch of 16 examples.\n",
"\n",
"Passing the `--overfit_batches 1` flag limits the number of training batches to 1\n",
"and turns off\n",
"[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n",
"so that in each epoch, the model just sees the same single batch of data over and over again.\n",
"\n",
"At first, try training the model to a loss of `2.5` --\n",
"it should be doable in 100 epochs or less,\n",
"which is just a few minutes on a commodity GPU.\n",
"\n",
"Once you've got that working,\n",
"crank up the number of epochs by a factor of 10\n",
"and confirm that the loss continues to go down.\n",
"\n",
"Some tips:\n",
"\n",
"- Use `--limit_test_batches 0` to turn off testing.\n",
"We don't need it because we don't care about generalization\n",
"and it's relatively slow because it runs the model autoregressively.\n",
"\n",
"- Use `--help` and look through the model class args\n",
"to find the arguments used to reduce model size.\n",
"\n",
"- By default, there's lots of regularization to prevent overfitting.\n",
"Look through the args for the model class and data class\n",
"for regularization knobs to turn off or down."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab03_transformers.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab04/notebooks/lab04_experiments.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" ", *characters, *iam_characters]
if __name__ == "__main__":
load_and_print_info(EMNIST)
================================================
FILE: lab04/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab04/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME
# Defaults for the synthetic-line generator; EMNISTLines reads overrides from its args.
DEFAULT_MAX_LENGTH = 32
# Overlaps are fractions of a character's width (see the --min_overlap/--max_overlap help text).
DEFAULT_MIN_OVERLAP = 0
DEFAULT_MAX_OVERLAP = 0.33
# Number of synthetic lines generated for each split.
NUM_TRAIN = 10000
NUM_VAL = 2000
NUM_TEST = 2000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.

    `prepare_data` generates the train/val/test lines once and caches them in an HDF5 file whose
    name encodes the generation parameters; `setup` loads the requested splits from that file.
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)
        # NOTE(review): self.args is read with .get, so BaseDataModule presumably stores the CLI
        # args as a dict — confirm in base_data_module.
        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)

        self.mapping = metadata.MAPPING
        self.output_dims = (self.max_length, 1)
        # Wide enough for max_length characters placed side by side with no overlap.
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)

        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """Path of the HDF5 cache.

        The generation parameters read in __init__ are part of the name, so changing any of them
        points prepare_data at a new (not yet existing) file.
        """
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate all three splits into the HDF5 cache unless that file already exists."""
        if self.data_filename.exists():
            return
        # Seed NumPy's global RNG so the np.random sampling used during generation is repeatable.
        np.random.seed(42)
        self._generate_data("train")
        self._generate_data("val")
        self._generate_data("test")

    def setup(self, stage: str = None) -> None:
        """Load the splits needed for `stage` ("fit", "test", or None for all) from the HDF5 cache."""
        print("EMNISTLinesDataset loading data from HDF5...")
        if stage == "fit" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)

            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)

        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset."""
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        # The detailed section reports every split's size and draws a train batch, so it needs all
        # three splits. After a partial setup (e.g. only stage="fit"), fall back to the basic info
        # instead of raising on len(None).
        if self.data_train is None or self.data_val is None or self.data_test is None:
            return basic

        x, y = next(iter(self.train_dataloader()))
        data = (
            f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
            f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
            f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
        )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Synthesize `split` lines and append x_{split}/y_{split} datasets to the HDF5 cache."""
        print(f"EMNISTLinesDataset generating data for {split}...")

        from text_recognizer.data.sentence_generator import SentenceGenerator

        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract two because we will add start/end tokens

        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()

        # train and val lines both draw glyphs from the EMNIST train/val pool; only test lines
        # use EMNIST's held-out test characters.
        if split == "train":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train
        elif split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test

        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Bucket samples by the character their label maps to.

    Samples and labels are paired positionally (stopping at the shorter of the two). The result
    is a defaultdict(list), so a character with no samples looks up as an empty list.
    """
    grouped = defaultdict(list)
    for img, lbl in zip(samples, labels):
        char = mapping[lbl]
        grouped[char].append(img)
    return grouped
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Return one glyph image per character of `string`, in string order.

    Each distinct character is sampled once (via np.random.choice) and reused for its repeats.
    A character whose sample list is empty gets an all-zero uint8 image of shape `char_shape`.
    Every returned image is reshaped to `char_shape`.
    """
    blank = torch.zeros(char_shape, dtype=torch.uint8)
    chosen = {}
    for ch in string:
        if ch in chosen:
            continue
        candidates = samples_by_char[ch]
        if candidates:
            picked = candidates[np.random.choice(len(candidates))]
        else:
            picked = blank
        chosen[ch] = picked.reshape(*char_shape)
    return [chosen[ch] for ch in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render `string` as one line image by pasting per-character glyphs left to right.

    A single overlap fraction is drawn with np.random.uniform(min_overlap, max_overlap) for the
    whole line. Consecutive glyphs are shifted by W - int(overlap * W) pixels, so neighbours
    overlap and their pixel values are summed.

    Parameters
    ----------
    string
        Text to render; must be non-empty (the glyph size is read from the first sample).
    samples_by_char
        Character -> list of glyph images, as built by get_samples_by_char.
    min_overlap, max_overlap
        Bounds on the overlap fraction of a glyph's width.
    width
        Width of the output canvas in pixels.

    Returns
    -------
    torch.Tensor
        (H, width) float32 image with values clipped to at most 255.
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in float32: summing overlapping glyphs into a uint8 canvas wraps values above 255
    # around to small numbers before any final clip could saturate them.
    concatenated_image = torch.zeros((H, width), dtype=torch.float32)
    x = 0
    for image in sampled_images:
        # as_tensor accepts both numpy arrays and torch tensors (the blank placeholder is a tensor)
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image, dtype=torch.float32)
        x += next_overlap_width
    return concatenated_image.clamp(max=255)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate N synthetic line images together with their text labels.

    `dims` is indexed as (_, height, width). Images are returned stacked in a float tensor of shape
    (N, dims[1], dims[2]); labels are the generated strings, in the same order.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    labels = []
    line_width = dims[-1]
    for row in range(N):
        text = sentence_generator.generate()
        labels.append(text)
        images[row] = construct_image_from_string(text, samples_by_char, min_overlap, max_overlap, line_width)
    return images, labels
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" "]
self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
self.val_cer = CharacterErrorRate(self.ignore_tokens)
self.test_cer = CharacterErrorRate(self.ignore_tokens)
================================================
FILE: lab04/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence
import torch
import torchmetrics
class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored.

    Predictions and targets are integer token tensors. Tokens listed in `ignore_tokens` are
    dropped from every sequence before the sequences are handed to torchmetrics.CharErrorRate.
    Any extra positional/keyword arguments (e.g. standard torchmetrics Metric options) are
    forwarded to the parent constructor.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        """Filter ignored tokens from each row of `preds` and `targets`, then accumulate."""
        # NOTE(review): the parent is documented for strings; lists of ints rely on its edit
        # distance working over arbitrary sequences.
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)
def test_character_error_rate():
    """Check CharacterErrorRate on rows where tokens 0 and 1 must be ignored."""
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    # Every filtered target is [2, 2, 3, 3] (same length), so the averaged per-row rate is the
    # expected value. Compare with a tolerance: the metric comes back as a float tensor, and
    # exact float equality is fragile.
    expected = sum([0, 0.75, 0.5]) / 3
    assert abs(metric.compute().item() - expected) < 1e-6
# Running this module directly executes its self-test.
if __name__ == "__main__":
    test_character_error_rate()
================================================
FILE: lab04/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence
import torch
from .base import BaseImageToTextLitModel
from .util import replace_after
class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.

    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """

    def __init__(self, model, args=None):
        super().__init__(model, args)
        # target positions holding the padding token do not contribute to the loss
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)

    def forward(self, x):
        return self.model(x)

    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.

        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.

        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # teacher forcing: feed tokens [0, S-1) and score the predictions against tokens [1, S)
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("train/loss", loss)

        outputs = {"loss": loss}

        if self.is_logged_batch():
            preds = self.get_preds(logits)
            pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
            outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs})

        return outputs

    def validation_step(self, batch, batch_idx):
        return self._evaluation_step(batch, batch_idx, "validation", self.val_cer)

    def test_step(self, batch, batch_idx):
        # Use the dedicated test metric: updating val_cer here would mix test batches into the
        # validation metric's accumulated state.
        return self._evaluation_step(batch, batch_idx, "test", self.test_cer)

    def _evaluation_step(self, batch, batch_idx, split, cer_metric):
        """Shared body of validation_step and test_step.

        Computes the teacher-forced loss (as in training) and the autoregressive predictions
        (as in production), updating and logging `cer_metric`. `split` is the prefix of the
        logged keys ("validation" or "test").
        """
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log(f"{split}/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)
        cer_metric(preds, y)
        self.log(f"{split}/cer", cer_metric, prog_bar=True, sync_dist=True)

        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping."""
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        else:
            return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping.

        A tensor input is converted to nested Python lists first, so `map` sees plain ints
        instead of 0-dim tensors (avoids per-element tensor comparisons and device syncs).
        """
        if isinstance(ks, torch.Tensor):
            ks = ks.tolist()
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.

        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.

        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        else:
            return raw  # (B, Sy)
================================================
FILE: lab04/text_recognizer/lit_models/util.py
================================================
from typing import Union
import torch
def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Return indices of first appearance of element in x, collapsing along dim.

    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9

    Parameters
    ----------
    x
        One or two-dimensional Tensor to search for element.
    element
        Item to search for inside x.
    dim
        Dimension of Tensor to collapse over.

    Returns
    -------
    torch.Tensor
        Indices where element occurs in x. If element is not found,
        return length of x along dim. One dimension smaller than x.

    Raises
    ------
    ValueError
        if x is not a 1 or 2 dimensional Tensor

    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    if x.dim() > 2 or x.dim() == 0:
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}")
    matches = x == element
    # The running count of matches equals 1 from the first match up to the second one;
    # AND-ing with `matches` keeps only the first match itself.
    first_appearance_mask = (matches.cumsum(dim) == 1) & matches
    # At most one True per slice, so max() yields its position; slices without a match report False.
    does_match, match_index = first_appearance_mask.max(dim)
    first_inds = torch.where(does_match, match_index, x.shape[dim])
    return first_inds


def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Replace all values in each row of 2d Tensor x after the first appearance of element with replace.

    Parameters
    ----------
    x
        Two-dimensional Tensor (shape denoted (B, S)) to replace values in.
    element
        Item to search for inside x.
    replace
        Item that replaces entries that appear after element.

    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    first_appearances = first_appearance(x, element, dim=1)  # (B,)
    # Positions stay int64 on x's device. Casting them to x's dtype would wrap for narrow integer
    # types (e.g. uint8 rows longer than 256) and lose precision for half-precision floats.
    indices = torch.arange(x.shape[-1], device=x.device)  # (S,)
    outs = torch.where(
        indices[None, :] <= first_appearances[:, None],  # if index is before first appearance
        x,  # return the value from x
        replace,  # otherwise, return the replacement value
    )
    return outs  # (B, S)
================================================
FILE: lab04/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path
import text_recognizer.metadata.shared as shared
# Locations of the raw, downloaded and processed EMNIST data under the shared data directory.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
# Resolves next to the package code (text_recognizer/data/), not under DATA_DIRNAME.
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"

NUM_SPECIAL_TOKENS = 4

INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)

# Index -> character lookup: a label's integer id is its position in this list, so the order
# is load-bearing and must match the stored labels.
# NOTE(review): NUM_SPECIAL_TOKENS is 4, but only two special-looking entries ("" and " ")
# precede the digits, and " " appears again after "z"; the special-token strings may have been
# mangled — confirm against the intended token set.
MAPPING = [
"",
" ",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"H",
"I",
"J",
"K",
"L",
"M",
"N",
"O",
"P",
"Q",
"R",
"S",
"T",
"U",
"V",
"W",
"X",
"Y",
"Z",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
" ",
"!",
'"',
"#",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
":",
";",
"?",
]
================================================
FILE: lab04/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
# Resolves next to the package code (text_recognizer/data/), not under DATA_DIRNAME.
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"

# Size of one character glyph, taken from the EMNIST (channels, height, width) dims.
CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)  # width variable, depends on maximum sequence length

# Lines reuse the EMNIST character mapping unchanged.
MAPPING = emnist.MAPPING
================================================
FILE: lab04/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared
# Locations of the raw and downloaded IAM data under the shared data directory.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"

DOWNSAMPLE_FACTOR = 2  # if images were downsampled, the regions must also be
LINE_REGION_PADDING = 8  # add this many pixels around the exact coordinates
================================================
FILE: lab04/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines"

# Line images are stored at 1/IMAGE_SCALE_FACTOR of the sizes written below (112 and 3072 px).
IMAGE_SCALE_FACTOR = 2

CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR  # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR
IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR  # rounding up IAMLines empirical maximum width

DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
# NOTE(review): 89 is presumably the maximum label length in tokens — confirm.
OUTPUT_DIMS = (89, 1)

# IAM lines reuse the EMNIST character mapping unchanged.
MAPPING = emnist.MAPPING
================================================
FILE: lab04/text_recognizer/metadata/iam_paragraphs.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs"

# Paragraph labels need a line-break symbol; it is appended after the EMNIST characters,
# so it takes the next free index.
NEW_LINE_TOKEN = "\n"
MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN]

IMAGE_SCALE_FACTOR = 2
# NOTE(review): presumably the paragraph image size after downscaling by IMAGE_SCALE_FACTOR — confirm.
IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH)

# NOTE(review): presumably the longest paragraph label in tokens — confirm against the dataset.
MAX_LABEL_LENGTH = 682

DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1)
================================================
FILE: lab04/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME

# (channels, height, width) of a single digit image.
DIMS = (1, 28, 28)
OUTPUT_DIMS = (1,)
# Class index i is the digit i.
MAPPING = list(range(10))

# Train/val sizes sum to 60000; presumably a split of MNIST's training set — confirm in the data module.
TRAIN_SIZE = 55000
VAL_SIZE = 5000
================================================
FILE: lab04/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# parents[3] climbs metadata/ -> text_recognizer/ -> lab directory -> its parent, so data lives
# in a "data" folder next to the lab directories rather than inside the package.
DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded"
================================================
FILE: lab04/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
from .resnet_transformer import ResnetTransformer
from .line_cnn_transformer import LineCNNTransformer
================================================
FILE: lab04/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
CONV_DIM = 64
FC_DIM = 128
FC_DROPOUT = 0.25


class ConvBlock(nn.Module):
    """A 3x3 convolution (stride 1, padding 1, so H and W are preserved) followed by a ReLU."""

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        # attribute names are part of the state_dict keys; keep them stable for checkpoints
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve then rectify.

        Parameters
        ----------
        x
            (B, C_in, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, C_out, H, W) tensor
        """
        return self.relu(self.conv(x))


class CNN(nn.Module):
    """Small CNN classifier for a single character in a square image.

    Two ConvBlocks, a 2x2 max-pool, dropout, then two fully-connected layers producing one
    logit per entry of data_config["mapping"].
    """

    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_channels, input_height, input_width = data_config["input_dims"]
        assert (
            input_height == input_width
        ), f"input height and width should be equal, but was {input_height}, {input_width}"
        self.input_height, self.input_width = input_height, input_width

        num_classes = len(data_config["mapping"])
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)

        # registration order is kept stable so parameter init and state_dict layout do not change
        self.conv1 = ConvBlock(input_channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)

        # The convs keep H and W unchanged; only the 2x2 pool halves them.
        pooled_height = input_height // 2
        pooled_width = input_width // 2
        self.fc_input_dim = int(pooled_height * pooled_width * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Parameters
        ----------
        x
            (B, Ch, H, W) tensor; H and W must match the input dims from data_config.

        Returns
        -------
        torch.Tensor
            (B, Cl) logits, one column per class.
        """
        _batch, _channels, height, width = x.shape
        assert height == self.input_height and width == self.input_width, f"bad inputs to CNN with shape {x.shape}"

        features = self.max_pool(self.conv2(self.conv1(x)))  # (B, CONV_DIM, H // 2, W // 2)
        features = torch.flatten(self.dropout(features), 1)  # (B, CONV_DIM * H // 2 * W // 2)
        hidden = F.relu(self.fc1(features))  # (B, FC_DIM)
        return self.fc2(hidden)  # (B, Cl)

    @staticmethod
    def add_to_argparse(parser):
        """Register the --conv_dim, --fc_dim and --fc_dropout options on `parser` and return it."""
        for flag, kind, default in (
            ("--conv_dim", int, CONV_DIM),
            ("--fc_dim", int, FC_DIM),
            ("--fc_dropout", float, FC_DROPOUT),
        ):
            parser.add_argument(flag, type=kind, default=default)
        return parser
================================================
FILE: lab04/text_recognizer/models/line_cnn.py
================================================
"""Basic building blocks for convolutional models over lines of text."""
import argparse
import math
from typing import Any, Dict, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
# Common type hints
Param2D = Union[int, Tuple[int, int]]

CONV_DIM = 32
FC_DIM = 512
FC_DROPOUT = 0.2
WINDOW_WIDTH = 16
WINDOW_STRIDE = 8


class ConvBlock(nn.Module):
    """
    Conv2d followed by a ReLU.

    With the default 3x3 kernel, stride 1 and padding 1 the spatial size is unchanged; other
    kernel_size/stride/padding values change it per the usual Conv2d output-size arithmetic.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: Param2D = 3,
        stride: Param2D = 1,
        padding: Param2D = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the ConvBlock to x.

        Parameters
        ----------
        x
            (B, C_in, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, C_out, H', W') tensor; H' == H and W' == W with the default arguments.
        """
        c = self.conv(x)
        r = self.relu(c)
        return r


class LineCNN(nn.Module):
    """
    Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.args = vars(args) if args is not None else {}
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        _C, H, _W = data_config["input_dims"]
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        # Input is (1, H, W)
        # Three stride-2 blocks shrink H and W by ~8x; the last block's kernel spans the remaining
        # height and one window's width, sliding by the window stride (both divided by 8).
        self.convs = nn.Sequential(
            ConvBlock(1, conv_dim),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim, stride=2),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim * 2, stride=2),
            ConvBlock(conv_dim * 2, conv_dim * 2),
            ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
            ConvBlock(conv_dim * 4, conv_dim * 4),
            ConvBlock(
                conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0
            ),
        )
        self.fc1 = nn.Linear(fc_dim, fc_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(fc_dim, self.num_classes)

        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights in a better way than default.

        Weights use Kaiming-normal init (fan_out mode); biases are drawn uniformly from
        [-1/sqrt(fan_out), 1/sqrt(fan_out)].

        See https://github.com/pytorch/pytorch/issues/18182
        """
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(t, a, b) samples U(a, b). The previous normal_(t, -bound, bound) read
                    # the bounds as (mean, std), producing biases centred on -bound.
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the LineCNN to a black-and-white input image.

        Parameters
        ----------
        x
            (B, 1, H, W) input image

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W and self.window_width
            C is self.num_classes
        """
        _B, _C, _H, _W = x.shape
        x = self.convs(x)  # (B, FC_DIM, 1, Sx)
        x = x.squeeze(2).permute(0, 2, 1)  # (B, S, FC_DIM)
        x = F.relu(self.fc1(x))  # -> (B, S, FC_DIM)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, S, C)
        x = x.permute(0, 2, 1)  # -> (B, C, S)
        if self.limit_output_length:
            x = x[:, :, : self.output_length]
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab04/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
IMAGE_SIZE = 28
WINDOW_WIDTH = IMAGE_SIZE
WINDOW_STRIDE = IMAGE_SIZE


class LineCNNSimple(nn.Module):
    """LeNet-based model that slides a window across a line image and classifies each window with a CNN.

    With the defaults (window width == window stride == IMAGE_SIZE) each character is expected to occupy its own
    non-overlapping IMAGE_SIZE-wide slot, i.e. the line width should be a multiple of IMAGE_SIZE.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.data_config = data_config

        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        # The wrapped CNN is configured for one square WW x WW window with the line's channel count
        in_channels = data_config["input_dims"][0]
        window_config = {**data_config, "input_dims": (in_channels, self.WW, self.WW)}
        self.cnn = CNN(data_config=window_config, args=args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CNN over every window of the line and collect the per-window logits.

        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to IMAGE_SIZE

        Returns
        -------
        torch.Tensor
            (B, num_classes, S) logits, where S = floor((W - window_width) / window_stride + 1)
        """
        batch_size, _channels, height, width = x.shape
        assert height == IMAGE_SIZE  # Make sure we can use our CNN class

        num_windows = math.floor((width - self.WW) / self.WS + 1)
        # type_as puts the buffer on x's device and dtype
        logits = torch.zeros((batch_size, self.num_classes, num_windows)).type_as(x)
        for idx in range(num_windows):
            left = idx * self.WS
            logits[:, :, idx] = self.cnn(x[:, :, :, left : left + self.WW])  # window: (B, C, H, WW)

        if self.limit_output_length:
            # S might not match the ground truth, so keep only as many positions as the labels have
            logits = logits[:, :, : self.output_length]
        return logits

    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        for flag, default, what in (
            ("--window_width", WINDOW_WIDTH, "Width"),
            ("--window_stride", WINDOW_STRIDE, "Stride"),
        ):
            parser.add_argument(
                flag,
                type=int,
                default=default,
                help=f"{what} of the window that will slide over the input image.",
            )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab04/text_recognizer/models/line_cnn_transformer.py
================================================
"""Model that combines a LineCNN with a Transformer model for text prediction."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .line_cnn import LineCNN
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding
TF_DIM = 256
TF_FC_DIM = 256
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4


class LineCNNTransformer(nn.Module):
    """Process the line through a CNN and process the resulting sequence with a Transformer decoder."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        self.start_token = inverse_mapping[" "]
        # decode() and forward() read end_token and padding_token, so both must be set here. They were previously
        # never assigned, and inference failed with AttributeError.
        # NOTE(review): assumes the mapping contains the "<E>" end and "<P>" padding special tokens, and that " " is
        # the start token used when the training targets were built. Confirm against the data module's mapping.
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # Instantiate LineCNN with "num_classes" set to self.dim, so it emits a dim-sized vector per window
        data_config_for_line_cnn = {**data_config}
        data_config_for_line_cnn["mapping"] = list(range(self.dim))
        self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args)
        # LineCNN outputs (B, E, S) activations, with E == dim

        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.pos_encoder = PositionalEncoding(d_model=self.dim)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def init_weights(self):
        """Initialize the token embedding and output projection uniformly in [-0.1, 0.1], with a zero output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, 1, H, W) image, as expected by LineCNN

        Returns
        -------
        torch.Tensor
            (Sx, B, E) sequence of embeddings
        """
        x = self.line_cnn(x)  # (B, E, Sx)
        x = x * math.sqrt(self.dim)
        x = x.permute(2, 0, 1)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x using preceding ground truth y.

        Parameters
        ----------
        x
            (Sx, B, E) image encoded as a sequence
        y
            (B, Sy) with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) logits
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict sequences of tokens from input images auto-regressively.

        Parameters
        ----------
        x
            (B, 1, H, W) image, as expected by LineCNN

        Returns
        -------
        torch.Tensor
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, S)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token
            # Early stop once every sequence has emitted an end or padding token. The remaining positions are
            # already padding, which is what the clean-up loop below would produce anyway.
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    @staticmethod
    def add_to_argparse(parser):
        LineCNN.add_to_argparse(parser)
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab04/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
FC1_DIM = 1024
FC2_DIM = 128
FC_DROPOUT = 0.5


class MLP(nn.Module):
    """Simple MLP suitable for recognizing single characters.

    flatten -> Linear(fc1) -> ReLU -> Dropout -> Linear(fc2) -> ReLU -> Dropout -> Linear(num_classes)
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.data_config = data_config

        # np.prod returns a NumPy integer; cast it so nn.Linear stores a plain Python int as in_features
        input_dim = int(np.prod(self.data_config["input_dims"]))
        num_classes = len(self.data_config["mapping"])

        fc1_dim = self.args.get("fc1", FC1_DIM)
        fc2_dim = self.args.get("fc2", FC2_DIM)
        dropout_p = self.args.get("fc_dropout", FC_DROPOUT)

        self.fc1 = nn.Linear(input_dim, fc1_dim)
        # One Dropout module, applied after each hidden layer
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, num_classes)

    def forward(self, x):
        """Return (B, num_classes) logits for a batch whose per-sample dims flatten to input_dims."""
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--fc1", type=int, default=FC1_DIM)
        parser.add_argument("--fc2", type=int, default=FC2_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab04/text_recognizer/models/resnet_transformer.py
================================================
"""Model combining a ResNet with a Transformer for image-to-sequence tasks."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
import torchvision
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage
TF_DIM = 256
TF_FC_DIM = 1024
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4
RESNET_DIM = 512  # hard-coded


class ResnetTransformer(nn.Module):
    """Pass an image through a Resnet and decode the resulting embedding with a Transformer."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        self.mapping = data_config["mapping"]
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        self.start_token = inverse_mapping[" "]
        # decode() and forward() read end_token and padding_token, so both must be set here. They were previously
        # never assigned, and inference failed with AttributeError.
        # NOTE(review): assumes the mapping contains the "<E>" end and "<P>" padding special tokens, and that " " is
        # the start token used when the training targets were built. Confirm against the data module's mapping.
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # ## Encoder part - should output vector sequence of length self.dim per sample
        resnet = torchvision.models.resnet18(weights=None)
        self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2]))  # Exclude AvgPool and Linear layers
        # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32

        self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
        # encoder_projection will output (B, dim, _H, _W) logits

        self.enc_pos_encoder = PositionalEncodingImage(
            d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
        )  # Max (Ho, Wo)

        # ## Decoder part
        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce sequences of labels from input images.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        output_tokens
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, Sy)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    def init_weights(self):
        """Initialize the decoder embedding/projection uniformly and the encoder projection Kaiming-style.

        The encoder projection bias is drawn uniformly from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        """
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

        nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
        if self.encoder_projection.bias is not None:
            _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # uniform_ takes (low, high). The previous normal_(bias, -bound, bound) read these as (mean, std).
            nn.init.uniform_(self.encoder_projection.bias, -bound, bound)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps
        """
        _B, C, _H, _W = x.shape
        if C == 1:
            # The ResNet stem expects 3 channels, so grayscale input is replicated
            x = x.repeat(1, 3, 1, 1)
        x = self.resnet(x)  # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs
        x = self.encoder_projection(x)  # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs

        # Scaling by sqrt(self.dim) here, as is done for the decoder embeddings, prevented any learning
        x = self.enc_pos_encoder(x)  # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32
        x = torch.flatten(x, start_dim=2)  # (B, E, Ho * Wo)
        x = x.permute(2, 0, 1)  # (Sx, B, E); Sx = Ho * Wo
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x with guiding sequences y.

        During autoregressive inference, the guiding sequence will be previous predictions.

        During training, the guiding sequence will be the ground truth.

        Parameters
        ----------
        x
            (Sx, B, E) images encoded as sequences of embeddings
        y
            (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) batch of logit sequences
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.dec_pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        # Default must match the TF_FC_DIM fallback used in __init__ (it previously defaulted to TF_DIM)
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab04/text_recognizer/models/transformer_util.py
================================================
"""Position Encoding and other utilities for Transformers."""
import math
import torch
from torch import Tensor
import torch.nn as nn
class PositionalEncodingImage(nn.Module):
    """
    Module used to add 2-D positional encodings to the feature-map produced by the encoder.

    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.

    The first d_model // 2 channels encode the row index and the remaining d_model // 2 the column index, each built
    with the 1-D sinusoidal table from PositionalEncoding.make_pe.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        table = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w)  # (d_model, max_h, max_w)
        # The table can be rebuilt from the constructor args, so by default it is kept out of the state_dict
        self.register_buffer("pe", table, persistent=persistent)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
        half = d_model // 2
        rows = PositionalEncoding.make_pe(d_model=half, max_len=max_h)  # (max_h, 1, half)
        rows = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (half, max_h, max_w): constant along width
        cols = PositionalEncoding.make_pe(d_model=half, max_len=max_w)  # (max_w, 1, half)
        cols = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (half, max_h, max_w): constant along height
        return torch.cat((rows, cols), dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encoding's top-left (H, W) corner to a (B, d_model, H, W) feature map."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        return x + self.pe[:, :height, :width]  # type: ignore
class PositionalEncoding(torch.nn.Module):
    """Classic Attention-is-all-you-need positional encoding.

    Adds a fixed sinusoidal table to a (S, B, d_model) sequence, then applies dropout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = self.make_pe(d_model=d_model, max_len=max_len)  # (max_len, 1, d_model)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> torch.Tensor:
        """Build the (max_len, 1, d_model) sinusoidal table.

        Even columns hold sin(pos * w_k) and odd columns cos(pos * w_k). Odd d_model is supported: there is one
        fewer odd column than frequencies, so the cosines use only the first d_model // 2 frequencies.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # Previously a shape mismatch for odd d_model (e.g. PositionalEncodingImage with d_model = 6)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)
        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.shape = (S, B, d_model)
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Build a (size, size) additive causal attention mask.

    Entry (i, j) is 0.0 when j <= i (position i may attend to position j) and -inf when j > i.
    """
    blocked = torch.full((size, size), float("-inf"))
    # triu with diagonal=1 keeps -inf strictly above the diagonal and writes 0 on and below it
    return torch.triu(blocked, diagonal=1)
================================================
FILE: lab04/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """A stem for models operating on images.

    Inputs are presumed to be PIL images, as is standard for torchvision Datasets. A call runs three stages:

    1. ``pil_transforms``: PIL image in, PIL image out (identity by default);
    2. ``pil_to_tensor``: PIL image to torch tensor;
    3. ``torch_transforms``: tensor in, tensor out (identity by default), run under ``torch.no_grad()``.

    The torch stage is wrapped in a torch.nn.Sequential and so is compatible with torchscript if the underlying
    Modules are compatible.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()

    def __call__(self, img):
        tensor = self.pil_to_tensor(self.pil_transforms(img))
        with torch.no_grad():
            return self.torch_transforms(tensor)
class MNISTStem(ImageStem):
    """Image stem for MNIST: the base PIL-to-tensor conversion followed by a fixed normalization."""

    def __init__(self):
        super().__init__()
        # Single-channel mean (0.1307) and std (0.3081) passed to Normalize
        mnist_normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.torch_transforms = torch.nn.Sequential(mnist_normalize)
================================================
FILE: lab04/text_recognizer/stems/line.py
================================================
import random
from PIL import Image
from torchvision import transforms
import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.stems.image import ImageStem
class LineStem(ImageStem):
    """A stem for handling images containing a line of text.

    When ``augment`` is True the PIL stage becomes ColorJitter followed by RandomAffine; otherwise the ImageStem
    defaults are kept. Either kwargs dict may be passed to override the defaults below.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()
        color_jitter_kwargs = {"brightness": (0.5, 1)} if color_jitter_kwargs is None else color_jitter_kwargs
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 3,
                "translate": (0, 0.05),
                "scale": (0.4, 1.1),
                "shear": (-40, 50),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        if not augment:
            return

        jitter = transforms.ColorJitter(**color_jitter_kwargs)
        affine = transforms.RandomAffine(**random_affine_kwargs)
        self.pil_transforms = transforms.Compose([jitter, affine])
class IAMLineStem(ImageStem):
    """A stem for handling images containing lines of text from the IAMLines dataset.

    Every crop is resized to metadata.IMAGE_HEIGHT (keeping its aspect ratio) and pasted onto a blank
    IMAGE_WIDTH x IMAGE_HEIGHT "L"-mode canvas. With ``augment`` the width is also randomly stretched, and
    ColorJitter and RandomAffine are applied afterwards.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()

        def _embed_in_canvas(crop, augment=augment):
            # crop is expected to be a PIL image of mode "L" (values 0 -> 255)
            canvas = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))

            src_width, src_height = crop.size
            target_height = metadata.IMAGE_HEIGHT
            target_width = int(target_height * (src_width / src_height))
            if augment:
                # Random horizontal stretch of up to +/-10%, capped at the canvas width
                target_width = min(int(target_width * random.uniform(0.9, 1.1)), metadata.IMAGE_WIDTH)
            resized = crop.resize((target_width, target_height), resample=Image.BILINEAR)

            offset_x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - target_width)
            offset_y = metadata.IMAGE_HEIGHT - target_height
            canvas.paste(resized, (offset_x, offset_y))
            return canvas

        color_jitter_kwargs = {"brightness": (0.8, 1.6)} if color_jitter_kwargs is None else color_jitter_kwargs
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 1,
                "shear": (-30, 20),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        steps = [transforms.Lambda(_embed_in_canvas)]
        if augment:
            steps.extend(
                [
                    transforms.ColorJitter(**color_jitter_kwargs),
                    transforms.RandomAffine(**random_affine_kwargs),
                ]
            )
        self.pil_transforms = transforms.Compose(steps)
================================================
FILE: lab04/text_recognizer/stems/paragraph.py
================================================
"""IAMParagraphs Stem class."""
import torchvision.transforms as transforms
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.stems.image import ImageStem
IMAGE_HEIGHT = metadata.IMAGE_HEIGHT
IMAGE_WIDTH = metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE
MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH


class ParagraphStem(ImageStem):
    """A stem for handling images that contain a paragraph of text.

    Without augmentation the image is simply center-cropped to IMAGE_SHAPE. With augmentation a random
    crop/pad to IMAGE_SHAPE is combined with photometric and geometric jitter; each transform's kwargs can be
    overridden by the caller.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()

        if not augment:
            self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)])
            return

        if color_jitter_kwargs is None:
            color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4}
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 3,
                "shear": 6,
                "scale": (0.95, 1),
                "interpolation": transforms.InterpolationMode.BILINEAR,
            }
        if random_perspective_kwargs is None:
            random_perspective_kwargs = {
                "distortion_scale": 0.2,
                "p": 0.5,
                "interpolation": transforms.InterpolationMode.BILINEAR,
            }
        if gaussian_blur_kwargs is None:
            gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
        if sharpness_kwargs is None:
            sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5}

        # IMAGE_SHAPE is (576, 640); pad_if_needed lets smaller pages be zero-padded up to it
        crop = transforms.RandomCrop(size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant")
        steps = [
            transforms.ColorJitter(**color_jitter_kwargs),
            crop,
            transforms.RandomAffine(**random_affine_kwargs),
            transforms.RandomPerspective(**random_perspective_kwargs),
            transforms.GaussianBlur(**gaussian_blur_kwargs),
            transforms.RandomAdjustSharpness(**sharpness_kwargs),
        ]
        self.pil_transforms = transforms.Compose(steps)
================================================
FILE: lab04/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """1-hot encode a tensor.

    Each integer label in ``y`` selects a row of the ``num_classes`` x ``num_classes`` uint8 identity matrix, so the
    result has shape ``(*np.shape(y), num_classes)``.
    """
    one_hot_rows = np.eye(num_classes, dtype="uint8")
    return one_hot_rows[y]
def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Open the image at ``image_uri`` with smart_open and return it as a PIL image.

    The return annotation names the ``PIL.Image.Image`` class; it previously named the ``PIL.Image`` module.

    Parameters
    ----------
    image_uri
        Local path or URI of the image, opened in binary mode via smart_open.
    grayscale
        If True, the image is converted to single-channel "L" mode.
    """
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Read a PIL image from an open binary file object.

    The return annotation names the ``PIL.Image.Image`` class; it previously named the ``PIL.Image`` module.

    Parameters
    ----------
    image_file
        Binary file object (or path) accepted by ``PIL.Image.open``.
    grayscale
        If True, convert to "L" mode; otherwise the image keeps its original mode.
    """
    with Image.open(image_file) as image:
        # convert() returns a new Image object, so the result does not depend on the handle the `with` closes
        target_mode = "L" if grayscale else image.mode
        return image.convert(mode=target_mode)
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Run the ``with`` body with ``working_dir`` as the current directory.

    The previous working directory is restored on exit, including when the body raises.
    """
    previous_dir = os.getcwd()
    os.chdir(working_dir)
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(os.chdir, previous_dir)
        yield
def compute_sha256(filename: Union[Path, str], chunk_size: int = 1 << 20):
    """Return SHA256 checksum of a file.

    The file is hashed incrementally in ``chunk_size``-byte pieces, so large files are not read into memory at once.

    Parameters
    ----------
    filename
        Path of the file to hash.
    chunk_size
        Number of bytes read per step; must be positive.

    Returns
    -------
    str
        Hex digest of the file's contents.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive. A read size of 0 would otherwise stop at once and hash nothing.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """tqdm progress bar with an ``update_to`` hook suitable as a ``urlretrieve`` reporthook.

    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """
        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If None (the default), the total is left unchanged.
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # update() takes an increment, so pass the difference from what is already counted; afterwards n == transferred
        self.update(transferred - self.n)
def download_url(url, filename):
    """Download ``url`` to the local path ``filename``, showing a byte-based progress bar."""
    bar_options = {"unit": "B", "unit_scale": True, "unit_divisor": 1024, "miniters": 1}
    with TqdmUpTo(**bar_options) as progress_bar:
        urlretrieve(url, filename, reporthook=progress_bar.update_to, data=None)  # noqa: S310
================================================
FILE: lab04/training/__init__.py
================================================
================================================
FILE: lab04/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import callbacks as cb
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# In order to ensure reproducible experiments, we must set random seeds.
# Both global generators are seeded once, at import time, before main() builds any data or model objects.
torch.manual_seed(42)
np.random.seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments.

    --data_class and --model_class are read first with parse_known_args, so the selected classes can register their
    own argument groups. --help is only added at the very end; an earlier help action would print and exit before
    those groups exist.
    """
    base_parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(base_parser)
    trainer_parser._action_groups[1].title = "Trainer Args"
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    parser.set_defaults(max_epochs=1)

    # Basic arguments, registered in this order
    basic_arguments = {
        "--wandb": {
            "action": "store_true",
            "default": False,
            "help": "If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.",
        },
        "--data_class": {
            "type": str,
            "default": "MNIST",
            "help": f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
        },
        "--model_class": {
            "type": str,
            "default": "MLP",
            "help": f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
        },
        "--load_checkpoint": {
            "type": str,
            "default": None,
            "help": "If passed, loads a model from the provided path.",
        },
        "--stop_early": {
            "type": int,
            "default": 0,
            "help": "If non-zero, applies early stopping, with the provided value as the 'patience' argument."
            + " Default is 0.",
        },
    }
    for flag, options in basic_arguments.items():
        parser.add_argument(flag, **options)

    # Resolve the chosen data and model classes now so they can contribute their own arguments
    known_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{known_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{known_args.model_class}")

    # Get data, model, and LitModel specific arguments, each in its own help group
    for title, register in (
        ("Data Args", data_class.add_to_argparse),
        ("Model Args", model_class.add_to_argparse),
        ("LitModel Args", lit_models.BaseLitModel.add_to_argparse),
    ):
        register(parser.add_argument_group(title))

    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create ``experiment_dir`` (and any missing parents) from the rank-zero process only; no-op if it exists."""
    directory = Path(experiment_dir)
    directory.mkdir(exist_ok=True, parents=True)
def main():
    """
    Run an experiment.

    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```

    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```

    The available command line args differ depending on some of the arguments, including --model_class and --data_class.

    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```

    Steps: build the data module and model from the CLI args, wrap the model in a LitModel (optionally restored from
    --load_checkpoint), set up the logger and callbacks, tune and fit, then test using the best checkpoint if one
    was saved (otherwise the in-memory model).
    """
    parser = _setup_parser()
    args = parser.parse_args()
    data, model = setup_data_and_model_from_args(args)

    # --loss is presumably registered by BaseLitModel.add_to_argparse; "transformer" selects TransformerLitModel
    lit_model_class = lit_models.BaseLitModel

    if args.loss == "transformer":
        lit_model_class = lit_models.TransformerLitModel

    if args.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)
    else:
        lit_model = lit_model_class(args=args, model=model)

    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    logger = pl.loggers.TensorBoardLogger(log_dir)
    experiment_dir = logger.log_dir

    # Checkpoints are ranked by validation CER for the transformer loss and by validation loss otherwise
    goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss"
    # auto_insert_metric_name=False (below), so the labels are spelled out; the literal parts use "." rather than
    # "/" (presumably so the filename does not create subdirectories)
    filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
    if goldstar_metric == "validation/cer":
        filename_format += "-validation.cer={validation/cer:.3f}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename=filename_format,
        monitor=goldstar_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    summary_callback = pl.callbacks.ModelSummary(max_depth=2)

    callbacks = [summary_callback, checkpoint_callback]
    if args.wandb:
        # The W&B logger replaces the TensorBoard logger for the Trainer
        logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train")
        logger.watch(model, log_freq=max(100, args.log_every_n_steps))
        logger.log_hyperparams(vars(args))
        # NOTE(review): experiment_dir is not read again below; checkpoint_callback already captured the
        # TensorBoard log_dir as its dirpath.
        experiment_dir = logger.experiment.dir
    callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()]
    if args.stop_early:
        # Early stopping always watches validation loss, even when checkpoints are ranked by CER
        early_stopping_callback = pl.callbacks.EarlyStopping(
            monitor="validation/loss", mode="min", patience=args.stop_early
        )
        callbacks.append(early_stopping_callback)

    if args.wandb and args.loss in ("transformer",):
        callbacks.append(cb.ImageToTextLogger())

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)

    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate

    trainer.fit(lit_model, datamodule=data)

    # Test the best saved checkpoint when there is one; otherwise test the in-memory model
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        rank_zero_info(f"Best model saved at: {best_model_path}")
        if args.wandb:
            rank_zero_info("Best model also uploaded to W&B ")
        trainer.test(datamodule=data, ckpt_path=best_model_path)
    else:
        trainer.test(lit_model, datamodule=data)
if __name__ == "__main__":
    # Script entry point, e.g. `python training/run_experiment.py --help`
    main()
================================================
FILE: lab04/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
# Packages in which args.data_class / args.model_class names are looked up
# by setup_data_and_model_from_args.
DATA_CLASS_MODULE = "text_recognizer.data"
MODEL_CLASS_MODULE = "text_recognizer.models"
def import_class(module_and_class_name: str) -> type:
    """Resolve a dotted path such as 'text_recognizer.models.MLP' to the object it names.

    Everything before the last dot is imported as a module; the last component
    is then fetched from that module as an attribute.

    Raises
    ------
    ValueError
        If the path contains no dot.
    ModuleNotFoundError
        If the module part cannot be imported.
    AttributeError
        If the module has no attribute with the final name.
    """
    module_path, attr_name = module_and_class_name.rsplit(".", maxsplit=1)
    return getattr(importlib.import_module(module_path), attr_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Instantiate the data module and model selected by ``args``.

    ``args.data_class`` is resolved inside ``DATA_CLASS_MODULE`` and
    ``args.model_class`` inside ``MODEL_CLASS_MODULE``. The data module is
    constructed first because the model's constructor receives its ``config()``.

    Returns
    -------
    tuple
        ``(data, model)``.
    """
    data_cls = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}")
    model_cls = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}")

    data = data_cls(args)
    return data, model_cls(data_config=data.config(), args=args)
================================================
FILE: lab05/.flake8
================================================
[flake8]
select = ANN,B,B9,BLK,C,D,E,F,I,S,W
# only check selected error codes
max-complexity = 12
# C9 - flake8 McCabe Complexity checker -- threshold
max-line-length = 120
# E501 - flake8 -- line length too long, actually handled by black
extend-ignore =
# E W - flake8 PEP style check
E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks
# S - flake8-bandit safety check
S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password
# ANN - flake8-annotations type annotation check
ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some
# D1 - flake8-docstrings docstring style check
D100,D102,D103,D104,D105, # missing docstrings
# D2 D4 - flake8-docstrings docstring style check
D200,D205,D400,D401, # whitespace issues and first line content
# DAR - flake8-darglint docstring correctness check
DAR103, # mismatched or missing type in docstring
application-import-names = app_gradio,text_recognizer,tests,training
# flake8-import-order: which names are first party?
import-order-style = google
# flake8-import-order: which import order style guide do we use?
docstring-convention = numpy
# flake8-docstrings: which docstring style guide do we use?
strictness = short
# darglint: how "strict" are we with docstring completeness?
docstring-style = numpy
# darglint: which docstring style guide do we use?
suppress-none-returning = true
# flake8-annotations: do we allow un-annotated Nones in returns?
mypy-init-return = true
# flake8-annotations: do we allow init to have no return annotation?
per-file-ignores =
# list of case-by-case ignores, see files for details
*/__init__.py:F401,I
*/data/*.py:DAR
data/*.py:F,I
*text_recognizer/util.py:DAR101,F401
*training/run_experiment.py:I202
*app_gradio/app.py:I202
================================================
FILE: lab05/.github/workflows/pre-commit.yml
================================================
name: pre-commit
on:
pull_request:
push:
# allows this Action to be triggered manually
workflow_dispatch:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.10'
- uses: pre-commit/action@v3.0.0
================================================
FILE: lab05/.pre-commit-config.yaml
================================================
repos:
# a set of useful Python-based pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
# list of definitions and supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace # removes any whitespace at the ends of lines
- id: check-toml # check toml syntax by loading all toml files
- id: check-yaml # check yaml syntax by loading all yaml files
- id: check-json # check json syntax by loading all json files
- id: check-merge-conflict # check for files with merge conflict strings
args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge
- id: check-added-large-files # check that no "large" files have been added
args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server
- id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.)
- id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.)
# black python autoformatting
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
# additional configuration of black in pyproject.toml
# flake8 python linter with all the fixins
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08)
additional_dependencies: [
flake8-bandit, flake8-bugbear, flake8-docstrings,
flake8-import-order, darglint, mypy, pycodestyle, pydocstyle]
args: ["--config", ".flake8"]
# additional configuration of flake8 and extensions in .flake8
# shellcheck-py for linting shell files
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.8.0.4
hooks:
- id: shellcheck
================================================
FILE: lab05/notebooks/lab01_pytorch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" `.\n",
"\n",
"A model that always predicts ` ` can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab05/notebooks/lab03_transformers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" \", \"\")\n",
"\n",
"idx = random.randint(0, len(xs))\n",
"\n",
"print(show(ys[idx]))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4dT3UCNzTsoc"
},
"source": [
"The `ResnetTransformer` model can run on this data\n",
"if passed the `.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXL-vIGRr86D"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MMxa-oWyT01E"
},
"source": [
"Our models are now big enough\n",
"that we want to make use of GPU acceleration\n",
"as much as we can,\n",
"even when working on single inputs,\n",
"so let's cast to the GPU if we have one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-YyUM8LgvW0w"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"rnt.to(device); xs = xs.to(device); ys = ys.to(device);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-E3UdD4zUJi"
},
"source": [
"First, let's just pass it through the ResNet encoder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-LUUtlvaxrvg"
},
"outputs": [],
"source": [
"resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n",
" # resnet is designed for RGB images, so we replicate the input across channels 3 times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eimgJ5dnywjg"
},
"outputs": [],
"source": [
"resnet_idx = random.randint(0, len(resnet_embedding)) # re-execute to view a different channel\n",
"plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n",
"plt.axis(\"off\"); plt.colorbar(fraction=0.05);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These embeddings, though generated by random, untrained weights,\n",
"are not entirely useless.\n",
"\n",
"Before neural networks could be effectively\n",
"trained end to end,\n",
"they were often used with frozen random weights\n",
"everywhere except the final layer\n",
"(see e.g.\n",
"[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n",
"[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n",
"these methods were still competitive, and\n",
"[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n",
"provide a\n",
"[theoretical basis](https://arxiv.org/abs/2011.14522)\n",
"for understanding their performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ye6pW0ETzw2A"
},
"source": [
"The final result, though, is repetitive gibberish --\n",
"at the bare minimum, we need to train the unembedding/readout layer\n",
"in order to get reasonable text."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our architecture includes randomization with dropout,\n",
"so repeated runs of the cell below will generate different outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu3Pa7gLsFMo"
},
"outputs": [],
"source": [
"preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gvCXUbskv6XM"
},
"outputs": [],
"source": [
"print(show(preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without teacher forcing, runtime is also variable from iteration to iteration --\n",
"the model stops when it generates an \"end sequence\" or padding token,\n",
"which is not deterministic thanks to the dropout layers.\n",
"For similar reasons, runtime is variable across inputs.\n",
"\n",
"The variable runtime of autoregressive generation\n",
"is also not great for scaling.\n",
"In a distributed setting, as required for large scale,\n",
"forward passes need to be synced across devices,\n",
"and if one device is generating a batch of much longer sequences,\n",
"it will cause all the others to idle while they wait on it to finish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t76MSVRXV0V7"
},
"source": [
"Let's turn our model into a `TransformerLitModel`\n",
"so we can run with teacher forcing.\n",
"\n",
"> You may be wondering:\n",
" why isn't teacher forcing part of the PyTorch module?\n",
" In general, the `LightningModule`\n",
" should encapsulate things that are needed in training, validation, and testing\n",
" but not during inference.\n",
" The teacher forcing trick fits this paradigm,\n",
" even though it's so critical to what makes Transformers powerful. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qrHRKHowdDi"
},
"outputs": [],
"source": [
"import text_recognizer.lit_models\n",
"\n",
"lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlNaFqR50Oid"
},
"source": [
"Now we can use `.teacher_forward` if we also provide the target `ys`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lpZdqXS5wn0F"
},
"outputs": [],
"source": [
"forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Zx9SmsN0QLT"
},
"source": [
"This may not run faster than the `rnt.forward`,\n",
"since generations are always the maximum possible length,\n",
"but runtimes and output lengths are deterministic and constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tu-XNYpi0Qvi"
},
"source": [
"Forcing doesn't necessarily make our predictions better.\n",
"They remain highly repetitive gibberish."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JcEgify9w0sv"
},
"outputs": [],
"source": [
"forcing_preds = torch.argmax(forcing_outs, dim=0)\n",
"\n",
"print(show(forcing_preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xn6GGNzc9a3o"
},
"source": [
"## Training the `ResnetTransformer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uvZYsuSyWUXe"
},
"source": [
"We're finally ready to train this model on full paragraphs of handwritten text!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cJwC7b720Sd"
},
"source": [
"This is a more serious model --\n",
"it's the one we use in the\n",
"[deployed TextRecognizer application](http://fsdl.me/app).\n",
"It's much larger than the models we've seen so far,\n",
"so it can easily outstrip available compute resources,\n",
"in particular GPU memory.\n",
"\n",
"To help, we use\n",
"[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n",
"which shrinks the size of most of our floats by half,\n",
"which reduces memory consumption and can speed up computation.\n",
"\n",
"If your GPU has less than 8GB of available RAM,\n",
"you'll see a \"CUDA out of memory\" `RuntimeError`,\n",
"which is something of a\n",
"[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n",
"In this case, you can resolve it by reducing the `--batch_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w1mXlhfy04Nm"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gpus = int(torch.cuda.is_available())\n",
"\n",
"if gpus:\n",
" !nvidia-smi\n",
"else:\n",
" print(\"watch out! working with this model on a typical CPU is not feasible\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "os1vW1rPZ1dy"
},
"source": [
"Even with an okay GPU, like a\n",
"[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n",
"a single epoch of training can take over 10 minutes to run.\n",
"We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n",
"but you can remove those flags to see what full training looks like."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnF6dWFn4JlZ"
},
"source": [
"It can take a long time (overnight)\n",
"to train this model to decent performance on a single GPU,\n",
"so we'll focus on other pieces for the exercises.\n",
"\n",
"> At the time of writing in mid-2022, the cheapest readily available option\n",
"for training this model to decent performance on this dataset with this codebase\n",
"comes out around $10, using\n",
"[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n",
"See, for example,\n",
"[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n",
"and associated experiment.\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HufjdUZN0t4l",
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"# above %%magic times the cell, useful as a poor man's profiler\n",
"\n",
"%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n",
" --gpus={gpus} --batch_size 16 --precision 16 \\\n",
" --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L6fQ93ju3Iku"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udb1Ekjx3L63"
},
"source": [
"### 🌟 Try out gradient accumulation and other \"training tricks\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kpqViB4p3Wfb"
},
"source": [
"Larger batches are helpful not only for increasing parallelization\n",
"and amortizing fixed costs\n",
"but also for getting more reliable gradients.\n",
"Larger batches give gradients with less noise\n",
"and, up to a point, less gradient noise means faster convergence.\n",
"\n",
"But larger batches result in larger tensors,\n",
"which take up more GPU memory,\n",
"a resource that is tightly constrained\n",
"and device-dependent.\n",
"\n",
"Does that mean we are limited in the quality of our gradients\n",
"due to our machine size?\n",
"\n",
"Not entirely:\n",
"look up the `--accumulate_grad_batches`\n",
"argument to the `pl.Trainer`.\n",
"You should be able to understand why\n",
"it makes it possible to compute the same gradients\n",
"you would find for a batch of size `k * N`\n",
"on a machine that can only run batches up to size `N`.\n",
"\n",
"Accumulating gradients across batches is among the\n",
"[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n",
"Try some of them out!\n",
"Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2vtkmX830y3"
},
"source": [
"### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n",
"\n",
"While training this model to actually fit the whole dataset is infeasible\n",
"as a short exercise on commodity hardware,\n",
"it's practical to train this model to memorize a batch of 16 examples.\n",
"\n",
"Passing the `--overfit_batches 1` flag limits the number of training batches to 1\n",
"and turns off\n",
"[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n",
"so that in each epoch, the model just sees the same single batch of data over and over again.\n",
"\n",
"At first, try training the model to a loss of `2.5` --\n",
"it should be doable in 100 epochs or less,\n",
"which is just a few minutes on a commodity GPU.\n",
"\n",
"Once you've got that working,\n",
"crank up the number of epochs by a factor of 10\n",
"and confirm that the loss continues to go down.\n",
"\n",
"Some tips:\n",
"\n",
"- Use `--limit_test_batches 0` to turn off testing.\n",
"We don't need it because we don't care about generalization\n",
"and it's relatively slow because it runs the model autoregressively.\n",
"\n",
"- Use `--help` and look through the model class args\n",
"to find the arguments used to reduce model size.\n",
"\n",
"- By default, there's lots of regularization to prevent overfitting.\n",
"Look through the args for the model class and data class\n",
"for regularization knobs to turn off or down."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab03_transformers.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab05/notebooks/lab04_experiments.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" ", *characters, *iam_characters]
if __name__ == "__main__":
load_and_print_info(EMNIST)
================================================
FILE: lab05/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab05/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME
# Defaults for line generation; EMNISTLines.__init__ falls back to these when args omit them.
DEFAULT_MAX_LENGTH = 32  # max line length in characters
DEFAULT_MIN_OVERLAP = 0  # min overlap between neighbouring characters, as a fraction in [0, 1]
DEFAULT_MAX_OVERLAP = 0.33  # max overlap between neighbouring characters, as a fraction in [0, 1]
# Number of synthetic lines generated for each split.
NUM_TRAIN = 10000
NUM_VAL = 2000
NUM_TEST = 2000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.

    Lines are generated by ``prepare_data`` and cached in an HDF5 file whose name
    encodes every generation setting, so each distinct configuration gets its own file.
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)
        # Generation settings; each falls back to the module-level default when absent from args.
        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)

        self.mapping = metadata.MAPPING
        # One label per character position in the line.
        self.output_dims = (self.max_length, 1)
        # Wide enough to hold max_length characters side by side with no overlap.
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)

        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        """Register the base data module's options plus EMNISTLines' own, and return ``parser``."""
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """Path of the HDF5 cache file for the current generation settings."""
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate and cache the train/val/test splits unless the cache file already exists."""
        if self.data_filename.exists():
            # NOTE(review): a run interrupted mid-generation leaves a partial file that is
            # treated as complete here; delete it if setup() cannot find a split.
            return
        # Seed numpy's global RNG so generation is reproducible (assuming all sampling
        # goes through np.random).
        np.random.seed(42)
        self._generate_data("train")
        self._generate_data("val")
        self._generate_data("test")

    def setup(self, stage: str = None) -> None:
        """Load cached splits: train/val for stage "fit", test for "test", all three for None."""
        print("EMNISTLinesDataset loading data from HDF5...")
        if stage == "fit" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)

            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)

        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset."""
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        # Batch statistics are sampled from the train split; without it there is nothing more to show.
        if self.data_train is None:
            return basic

        x, y = next(iter(self.train_dataloader()))
        # setup("fit") leaves data_test unset (and setup("test") leaves train/val unset),
        # so report missing splits instead of crashing on len(None).
        sizes = ", ".join(
            str(len(split)) if split is not None else "None"
            for split in (self.data_train, self.data_val, self.data_test)
        )
        data = (
            f"Train/val/test sizes: {sizes}\n"
            f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
            f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
        )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Render ``split`` lines from EMNIST glyphs and append them to the HDF5 cache.

        Train and val lines both draw glyphs from EMNIST's trainval set; any other
        split name draws from EMNIST's test set.
        """
        print(f"EMNISTLinesDataset generating data for {split}...")

        from text_recognizer.data.sentence_generator import SentenceGenerator

        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract two because we will add start/end tokens

        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()

        if split == "train" or split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train if split == "train" else self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test

        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Group ``samples`` by the character their label maps to.

    ``mapping[label]`` turns each label into a character. The result is a
    ``defaultdict(list)`` from character to the samples carrying it, in input
    order, so characters never seen read back as an empty list.
    """
    grouped = defaultdict(list)
    for sample, label in zip(samples, labels):
        char = mapping[label]
        grouped[char].append(sample)
    return grouped
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Pick one glyph image per distinct character of ``string``; return them in string order.

    Each distinct character is sampled once (uniformly, via ``np.random.choice``) from
    ``samples_by_char[char]`` and that image is reused for every occurrence of the
    character. Characters with no samples get an all-zero uint8 image of ``char_shape``.

    Parameters
    ----------
    string
        Text to render.
    samples_by_char
        Mapping from character to a sequence of glyph images for that character.
    char_shape
        Shape every returned image is reshaped to.

    Returns
    -------
    list
        One image per character of ``string``.
    """
    zero_image = torch.zeros(char_shape, dtype=torch.uint8)
    sample_image_by_char = {}
    for char in string:
        if char in sample_image_by_char:
            continue
        # .get() avoids inserting empty entries into a caller's defaultdict and
        # also accepts plain dicts that lack the character.
        samples = samples_by_char.get(char, [])
        sample = samples[np.random.choice(len(samples))] if samples else zero_image
        sample_image_by_char[char] = sample.reshape(*char_shape)
    return [sample_image_by_char[char] for char in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render ``string`` as one line image by pasting glyph images left to right.

    A single overlap fraction is drawn uniformly from [min_overlap, max_overlap] for the
    line. Consecutive glyphs are offset by ``W - int(overlap * W)`` pixels, so neighbours
    may share columns. Overlapping pixel values are summed and then clipped at 255.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (H, width) with values in [0, 255], where H is the glyph height.
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in float32: summing overlapping glyphs in a uint8 buffer wraps around
    # modulo 256 (e.g. 200 + 100 -> 44) before the clip below could take effect.
    concatenated_image = torch.zeros((H, width), dtype=torch.float32)
    x = 0
    for image in sampled_images:
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image, dtype=torch.float32)
        x += next_overlap_width
    return torch.clamp(concatenated_image, max=255)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate ``N`` sentences and render each one as a line image.

    Returns a float tensor of shape (N, dims[1], dims[2]) holding the rendered images
    and the list of the ``N`` generated sentences, index-aligned with the images.
    The rendered images (width ``dims[-1]``) must fit a (dims[1], dims[2]) slot.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    labels = []
    for row in range(N):
        sentence = sentence_generator.generate()
        labels.append(sentence)
        images[row] = construct_image_from_string(sentence, samples_by_char, min_overlap, max_overlap, dims[-1])
    return images, labels
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" "]
self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
self.val_cer = CharacterErrorRate(self.ignore_tokens)
self.test_cer = CharacterErrorRate(self.ignore_tokens)
================================================
FILE: lab05/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence
import torch
import torchmetrics
class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored.

    Token ids in ``ignore_tokens`` (for example start, end and padding indices) are
    removed from both predictions and targets before they reach the parent metric.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        # Forward keyword options as well: torchmetrics Metric options are passed by keyword.
        super().__init__(*args, **kwargs)
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        """Accumulate errors for a batch of token-id sequences, one sequence per row.

        Parameters
        ----------
        preds
            Predicted token ids; each row is one predicted sequence.
        targets
            Ground-truth token ids; each row is one target sequence.
        """
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)
def test_character_error_rate():
    """Ignored tokens (0 and 1 here) are stripped before the error rate is computed."""
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    # compute() returns a float tensor; compare with a tolerance rather than exact equality.
    expected = sum([0, 0.75, 0.5]) / 3
    assert abs(metric.compute().item() - expected) < 1e-6
if __name__ == "__main__":
    # Running this module directly executes the CharacterErrorRate self-check.
    test_character_error_rate()
================================================
FILE: lab05/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence
import torch
from .base import BaseImageToTextLitModel
from .util import replace_after
class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.
    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """
    def __init__(self, model, args=None):
        super().__init__(model, args)
        # Padding positions in the targets do not contribute to the loss.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)
    def forward(self, x):
        return self.model(x)
    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.
        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.
        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)
    def training_step(self, batch, batch_idx):
        x, y = batch
        # Teacher forcing: condition on y[:, :-1] and score the one-step-shifted targets y[:, 1:].
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])
        self.log("train/loss", loss)
        outputs = {"loss": loss}
        if self.is_logged_batch():
            preds = self.get_preds(logits)
            pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
            outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs})
        return outputs
    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, "validation", self.val_cer)
    def test_step(self, batch, batch_idx):
        # Test CER must accumulate in its own metric: updating val_cer here would share state
        # with validation. Fall back to val_cer only if the base class defines no test_cer.
        test_cer = getattr(self, "test_cer", self.val_cer)
        return self._shared_eval_step(batch, batch_idx, "test", test_cer)
    def _shared_eval_step(self, batch, batch_idx, stage: str, cer_metric) -> dict:
        """Shared validation/test logic: teacher-forced loss plus autoregressive CER.
        Logs "{stage}/loss" and "{stage}/cer"; on the first batch, also returns decoded
        strings and detached logits in the outputs dict.
        """
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])
        self.log(f"{stage}/loss", loss, prog_bar=True, sync_dist=True)
        outputs = {"loss": loss}
        # compute predictions as in production, for comparison
        preds = self(x)
        cer_metric(preds, y)
        self.log(f"{stage}/cer", cer_metric, prog_bar=True, sync_dist=True)
        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)
        return outputs
    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers (or a 1-D tensor) to a string using the lit model's mapping."""
        if isinstance(ks, torch.Tensor):
            # Convert once instead of iterating 0-d tensors element by element.
            ks = ks.tolist()
        if ignore:
            return "".join(self.mapping[k] for k in ks if k not in self.ignore_tokens)
        return "".join(self.mapping[k] for k in ks)
    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers (or a 2-D tensor) to a list of strings using the lit model's mapping."""
        if isinstance(ks, torch.Tensor):
            ks = ks.tolist()
        return [self.map(k, ignore) for k in ks]
    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.
        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.
        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        else:
            return raw  # (B, Sy)
================================================
FILE: lab05/text_recognizer/lit_models/util.py
================================================
from typing import Union
import torch
def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Locate the first occurrence of element in x along dim.
    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9
    Parameters
    ----------
    x
        Tensor with one or two dimensions to search.
    element
        Value to look for.
    dim
        Dimension that is collapsed by the search.
    Returns
    -------
    torch.Tensor
        Index of the first match for each slice along dim; slices with no match get
        x.shape[dim]. The result has one dimension fewer than x.
    Raises
    ------
    ValueError
        if x is not a 1 or 2 dimensional Tensor
    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    if not 1 <= x.dim() <= 2:
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}")
    is_element = x == element
    # A position is the first hit when it matches and the running match count there is exactly 1.
    is_first = is_element & (is_element.cumsum(dim) == 1)
    found, position = is_first.max(dim)
    # Slices without any match report the length along dim instead of a bogus index 0.
    return torch.where(found, position, x.shape[dim])
def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Overwrite, row by row, every value that follows the first occurrence of element.
    Parameters
    ----------
    x
        Two-dimensional Tensor (shape denoted (B, S)) to replace values in.
    element
        Value whose first occurrence in each row marks the cut-off point.
    replace
        Value written into every position after the cut-off.
    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.
    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    stop_at = first_appearance(x, element, dim=1)  # (B,)
    positions = torch.arange(0, x.shape[-1]).type_as(x)  # (S,)
    # Keep everything up to and including the first occurrence; rows without one are kept whole.
    keep = positions.unsqueeze(0) <= stop_at.unsqueeze(1)  # (B, S)
    return torch.where(keep, x, replace)  # (B, S)
================================================
FILE: lab05/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path
import text_recognizer.metadata.shared as shared
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"
# The first NUM_SPECIAL_TOKENS entries of MAPPING are non-character tokens.
NUM_SPECIAL_TOKENS = 4
INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)
# Index -> token vocabulary. Every entry must be distinct: models build {token: index}
# inverse mappings from this list, and a duplicate silently resolves to its last index.
# (Previously only "" and " " preceded "0", so " " appeared twice and the start/end/padding
# tokens that sequence models look up were absent.)
MAPPING = [
    # special tokens -- exactly NUM_SPECIAL_TOKENS of them
    "<B>",
    "<S>",  # start token
    "<E>",  # end token
    "<P>",  # padding token
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "q",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z",
    " ",
    "!",
    '"',
    "#",
    "&",
    "'",
    "(",
    ")",
    "*",
    "+",
    ",",
    "-",
    ".",
    "/",
    ":",
    ";",
    "?",
]
================================================
FILE: lab05/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
# Cache location for the processed EMNIST lines dataset under the shared data directory.
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
# JSON file in the data/ directory next to metadata/ (parents[1] of this file).
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"
# Character height and width are taken from the EMNIST image dims (channels, height, width).
CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)  # width variable, depends on maximum sequence length
# Lines reuse the EMNIST token vocabulary unchanged.
MAPPING = emnist.MAPPING
================================================
FILE: lab05/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared
# Locations of IAM raw metadata and downloads under the shared data directory.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
# Directory the downloaded dataset is extracted into.
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"
DOWNSAMPLE_FACTOR = 2  # if images were downsampled, the regions must also be
LINE_REGION_PADDING = 8  # add this many pixels around the exact coordinates
================================================
FILE: lab05/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
# Cache location for processed IAM line images.
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines"
# Divisor applied to the base sizes below (e.g. height 112 // 2 == 56).
IMAGE_SCALE_FACTOR = 2
CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR  # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR
IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR  # rounding up IAMLines empirical maximum width
# Single-channel (1, H, W) input images.
DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
# Sequence models read output_dims[0] as their maximum output length; presumably 89 is the
# longest IAM line label including special tokens -- TODO confirm.
OUTPUT_DIMS = (89, 1)
# IAM lines reuse the EMNIST token vocabulary.
MAPPING = emnist.MAPPING
================================================
FILE: lab05/text_recognizer/metadata/iam_paragraphs.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
# Cache location for processed IAM paragraph images.
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs"
# Paragraph labels span several lines, so the vocabulary gains a newline token,
# appended after all EMNIST tokens (so existing indices are unchanged).
NEW_LINE_TOKEN = "\n"
MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN]
IMAGE_SCALE_FACTOR = 2
IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH)
# Maximum label length in tokens; exposed as output_dims[0], which sequence models
# use as their maximum output length.
MAX_LABEL_LENGTH = 682
DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1)
================================================
FILE: lab05/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME
# Single-channel 28x28 images, one label per image.
DIMS = (1, 28, 28)
OUTPUT_DIMS = (1,)
# Class index i is the digit i.
MAPPING = list(range(10))
# Presumably a train/validation split of the 60,000-image MNIST training set -- confirm in the data module.
TRAIN_SIZE = 55000
VAL_SIZE = 5000
================================================
FILE: lab05/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# data/ directory located parents[3] above this resolved file, i.e. the directory that
# contains lab05/ (this file is lab05/text_recognizer/metadata/shared.py).
DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded"
================================================
FILE: lab05/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
from .resnet_transformer import ResnetTransformer
from .line_cnn_transformer import LineCNNTransformer
================================================
FILE: lab05/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
# Defaults used by CNN when args lacks conv_dim / fc_dim / fc_dropout; also the argparse defaults.
CONV_DIM = 64  # output channels of each ConvBlock
FC_DIM = 128  # width of the hidden fully connected layer
FC_DROPOUT = 0.25  # dropout probability applied after max-pooling
class ConvBlock(nn.Module):
    """
    A 3x3 convolution (stride 1, padding 1, so height and width are preserved) and a ReLU.
    """
    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu = nn.ReLU()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve x and rectify the result.
        Parameters
        ----------
        x
            (B, C_in, H, W) tensor
        Returns
        -------
        torch.Tensor
            (B, C_out, H, W) non-negative tensor
        """
        return self.relu(self.conv(x))
class CNN(nn.Module):
    """Simple CNN for recognizing characters in a square image.
    Two size-preserving ConvBlocks, a 2x2 max-pool, dropout, then a two-layer
    fully connected head with one output logit per mapping entry.
    """
    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config
        channels, height, width = data_config["input_dims"]
        assert (
            height == width
        ), f"input height and width should be equal, but was {height}, {width}"
        self.input_height = height
        self.input_width = width
        n_classes = len(data_config["mapping"])
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)
        self.conv1 = ConvBlock(channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)
        # The padded 3x3 convs keep H and W; only the 2x2 max-pool halves them.
        pooled_h = height // 2
        pooled_w = width // 2
        self.fc_input_dim = int(pooled_h * pooled_w * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, n_classes)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of square images.
        Parameters
        ----------
        x
            (B, Ch, H, W) tensor; H and W must match the input dims given at construction.
        Returns
        -------
        torch.Tensor
            (B, Cl) logits, one column per class
        """
        _, _, h, w = x.shape
        assert h == self.input_height and w == self.input_width, f"bad inputs to CNN with shape {x.shape}"
        features = self.max_pool(self.conv2(self.conv1(x)))  # (B, CONV_DIM, H // 2, W // 2)
        features = torch.flatten(self.dropout(features), 1)  # (B, CONV_DIM * H // 2 * W // 2)
        hidden = F.relu(self.fc1(features))  # (B, FC_DIM)
        return self.fc2(hidden)  # (B, Cl)
    @staticmethod
    def add_to_argparse(parser):
        options = (
            ("--conv_dim", int, CONV_DIM),
            ("--fc_dim", int, FC_DIM),
            ("--fc_dropout", float, FC_DROPOUT),
        )
        for flag, kind, default in options:
            parser.add_argument(flag, type=kind, default=default)
        return parser
================================================
FILE: lab05/text_recognizer/models/line_cnn.py
================================================
"""Basic building blocks for convolutional models over lines of text."""
import argparse
import math
from typing import Any, Dict, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
# Common type hints
# A conv hyperparameter given either as one int (used for both dims) or as an (H, W) pair.
Param2D = Union[int, Tuple[int, int]]
# Defaults used by LineCNN when the matching args are absent; also the argparse defaults.
CONV_DIM = 32  # base channel count; deeper blocks use 2x and 4x this
FC_DIM = 512  # channels of the final windowed conv and width of fc1
FC_DROPOUT = 0.2
WINDOW_WIDTH = 16  # width of the window sliding over the input image, in input pixels
WINDOW_STRIDE = 8  # horizontal step between windows, in input pixels
class ConvBlock(nn.Module):
    """
    Conv2d followed by a ReLU; with the default 3x3 kernel, stride 1 and padding 1 the spatial size is unchanged.
    """
    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: Param2D = 3,
        stride: Param2D = 1,
        padding: Param2D = 1,
    ) -> None:
        super().__init__()
        conv_options = {"kernel_size": kernel_size, "stride": stride, "padding": padding}
        self.conv = nn.Conv2d(input_channels, output_channels, **conv_options)
        self.relu = nn.ReLU()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve x and rectify the result.
        Parameters
        ----------
        x
            (B, C_in, H, W) tensor
        Returns
        -------
        torch.Tensor
            (B, C_out, H', W') non-negative tensor; H', W' follow from kernel_size, stride and padding
        """
        return self.relu(self.conv(x))
class LineCNN(nn.Module):
    """
    Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits
    """
    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.args = vars(args) if args is not None else {}
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]
        C, H, _W = data_config["input_dims"]
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)
        # Input is (C, H, W). The first conv reads C channels (previously hard-coded to 1),
        # so the network is unchanged for single-channel inputs and also accepts multi-channel ones.
        # Three stride-2 blocks shrink H and W by roughly 8x; the final conv then spans the
        # remaining height and slides a window of WW // 8 columns with stride WS // 8.
        self.convs = nn.Sequential(
            ConvBlock(C, conv_dim),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim, stride=2),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim * 2, stride=2),
            ConvBlock(conv_dim * 2, conv_dim * 2),
            ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
            ConvBlock(conv_dim * 4, conv_dim * 4),
            ConvBlock(
                conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0
            ),
        )
        self.fc1 = nn.Linear(fc_dim, fc_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(fc_dim, self.num_classes)
        self._init_weights()
    def _init_weights(self):
        """
        Initialize weights in a better way than default.
        See https://github.com/pytorch/pytorch/issues/18182
        """
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # NOTE(review): normal_(t, mean, std) draws from N(-bound, bound**2), not
                    # U(-bound, bound); kept as-is to preserve training behavior -- confirm intent.
                    nn.init.normal_(m.bias, -bound, bound)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the LineCNN to an input image.
        Parameters
        ----------
        x
            (B, C, H, W) input image, with C and H matching data_config["input_dims"]
        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W, self.WW (window width) and self.WS (window stride)
            C is self.num_classes
        """
        _B, _C, _H, _W = x.shape
        x = self.convs(x)  # (B, FC_DIM, 1, Sx)
        x = x.squeeze(2).permute(0, 2, 1)  # (B, S, FC_DIM)
        x = F.relu(self.fc1(x))  # -> (B, S, FC_DIM)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, S, C)
        x = x.permute(0, 2, 1)  # -> (B, C, S)
        if self.limit_output_length:
            x = x[:, :, : self.output_length]
        return x
    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab05/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
# Input height required by LineCNNSimple.forward (it asserts H == IMAGE_SIZE).
IMAGE_SIZE = 28
# By default windows are IMAGE_SIZE wide and do not overlap (stride == width).
WINDOW_WIDTH = IMAGE_SIZE
WINDOW_STRIDE = IMAGE_SIZE
class LineCNNSimple(nn.Module):
    """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.
    A square WW x WW window slides across the line with stride WS, and each window is
    classified independently by a CNN.
    """
    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]
        # The inner CNN sees one square window at a time, so it is built for a WW x WW input.
        window_config = dict(data_config)
        window_config["input_dims"] = (data_config["input_dims"][0], self.WW, self.WW)
        self.cnn = CNN(data_config=window_config, args=args)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify every window of a line image and return the per-window logits.
        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to IMAGE_SIZE
        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the number of windows that fit in W and C is self.num_classes
        """
        batch_size, _C, height, width = x.shape
        assert height == IMAGE_SIZE  # Make sure we can use our CNN class
        # Number of full windows that fit across the line.
        num_windows = math.floor((width - self.WW) / self.WS + 1)
        # NOTE: type_as properly sets device
        activations = torch.zeros((batch_size, self.num_classes, num_windows)).type_as(x)
        for idx in range(num_windows):
            left = idx * self.WS
            activations[:, :, idx] = self.cnn(x[:, :, :, left : left + self.WW])
        if self.limit_output_length:
            # S might not match ground truth, so let's only take enough activations as are expected
            activations = activations[:, :, : self.output_length]
        return activations
    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        window_options = (
            ("--window_width", WINDOW_WIDTH, "Width of the window that will slide over the input image."),
            ("--window_stride", WINDOW_STRIDE, "Stride of the window that will slide over the input image."),
        )
        for flag, default, help_text in window_options:
            parser.add_argument(flag, type=int, default=default, help=help_text)
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab05/text_recognizer/models/line_cnn_transformer.py
================================================
"""Model that combines a LineCNN with a Transformer model for text prediction."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .line_cnn import LineCNN
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding
# Default Transformer decoder hyperparameters, used when the matching tf_* args are absent.
TF_DIM = 256  # embedding size; also the number of output channels requested from LineCNN
TF_FC_DIM = 256  # feed-forward width inside each decoder layer
TF_DROPOUT = 0.4
TF_LAYERS = 4  # number of stacked decoder layers
TF_NHEAD = 4  # attention heads per decoder layer
class LineCNNTransformer(nn.Module):
    """Process the line through a CNN and process the resulting sequence with a Transformer decoder."""
    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # decode/forward need distinct start, end and padding tokens. Resolve all three here and
        # fail fast with a clear message instead of an AttributeError on the first decode call.
        required_tokens = ("<S>", "<E>", "<P>")
        missing_tokens = [token for token in required_tokens if token not in inverse_mapping]
        if missing_tokens:
            raise ValueError(
                f"data_config['mapping'] must contain the special tokens {required_tokens}; missing {missing_tokens}"
            )
        self.start_token = inverse_mapping["<S>"]
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}
        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)
        # Instantiate LineCNN with "num_classes" set to self.dim
        data_config_for_line_cnn = {**data_config}
        data_config_for_line_cnn["mapping"] = list(range(self.dim))
        self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args)
        # LineCNN outputs (B, E, S) features, with E == dim
        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.y_mask = generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )
        self.init_weights()  # This is empirically important
    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.
        Parameters
        ----------
        x
            (B, C, H, W) image batch as expected by LineCNN (C presumably 1 -- confirm)
        Returns
        -------
        torch.Tensor
            (Sx, B, E) logits
        """
        x = self.line_cnn(x)  # (B, E, Sx)
        x = x * math.sqrt(self.dim)
        x = x.permute(2, 0, 1)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        return x
    def decode(self, x, y):
        """Decode a batch of encoded images x using preceding ground truth y.
        Parameters
        ----------
        x
            (Sx, B, E) image encoded as a sequence
        y
            (B, Sy) with elements in [0, C-1] where C is num_classes
        Returns
        -------
        torch.Tensor
            (Sy, B, C) logits
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict sequences of tokens from input images auto-regressively.
        Parameters
        ----------
        x
            (B, C, H, W) image batch as expected by LineCNN
        Returns
        -------
        torch.Tensor
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)
        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, S)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token
            # Early stopping: once every sequence emitted end/padding, the remaining positions
            # would be overwritten with padding below anyway, so stop decoding.
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break
        # Set all tokens after end token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token
        return output_tokens  # (B, Sy)
    @staticmethod
    def add_to_argparse(parser):
        LineCNN.add_to_argparse(parser)
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab05/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default MLP layer widths and dropout, overridable via the fc1 / fc2 / fc_dropout args.
FC1_DIM = 1024
FC2_DIM = 128
FC_DROPOUT = 0.5
class MLP(nn.Module):
    """Simple MLP suitable for recognizing single characters."""
    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        """Build a three-layer MLP sized from data_config.
        Parameters
        ----------
        data_config
            Must provide "input_dims" (multiplied together to get the flattened input width)
            and "mapping" (one output logit per entry).
        args
            Optional namespace; fc1, fc2 and fc_dropout override the module defaults.
        """
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.data_config = data_config
        # np.prod returns a NumPy integer; cast it so nn.Linear stores a plain Python int
        # (NumPy scalars in module attributes can break scripting/serialization).
        input_dim = int(np.prod(self.data_config["input_dims"]))
        num_classes = len(self.data_config["mapping"])
        fc1_dim = self.args.get("fc1", FC1_DIM)
        fc2_dim = self.args.get("fc2", FC2_DIM)
        dropout_p = self.args.get("fc_dropout", FC_DROPOUT)
        self.fc1 = nn.Linear(input_dim, fc1_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, num_classes)
    def forward(self, x):
        """Flatten x to (B, prod(input_dims)) and return (B, num_classes) logits."""
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x
    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--fc1", type=int, default=FC1_DIM)
        parser.add_argument("--fc2", type=int, default=FC2_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab05/text_recognizer/models/resnet_transformer.py
================================================
"""Model combining a ResNet with a Transformer for image-to-sequence tasks."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
import torchvision
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage
# Default Transformer decoder hyperparameters, used when the matching tf_* args are absent.
TF_DIM = 256  # embedding size of the decoder and of the projected image features
TF_FC_DIM = 1024  # feed-forward width inside each decoder layer
TF_DROPOUT = 0.4
TF_LAYERS = 4  # number of stacked decoder layers
TF_NHEAD = 4  # attention heads per decoder layer
RESNET_DIM = 512  # hard-coded; channel count of the truncated ResNet-18 feature maps
class ResnetTransformer(nn.Module):
"""Pass an image through a Resnet and decode the resulting embedding with a Transformer."""
def __init__(
self,
data_config: Dict[str, Any],
args: argparse.Namespace = None,
) -> None:
super().__init__()
self.data_config = data_config
self.input_dims = data_config["input_dims"]
self.num_classes = len(data_config["mapping"])
self.mapping = data_config["mapping"]
inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
self.start_token = inverse_mapping[" "]
self.max_output_length = data_config["output_dims"][0]
self.args = vars(args) if args is not None else {}
self.dim = self.args.get("tf_dim", TF_DIM)
tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
tf_layers = self.args.get("tf_layers", TF_LAYERS)
# ## Encoder part - should output vector sequence of length self.dim per sample
resnet = torchvision.models.resnet18(weights=None)
self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers
# Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32
self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
# encoder_projection will output (B, dim, _H, _W) logits
self.enc_pos_encoder = PositionalEncodingImage(
d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
) # Max (Ho, Wo)
# ## Decoder part
self.embedding = nn.Embedding(self.num_classes, self.dim)
self.fc = nn.Linear(self.dim, self.num_classes)
self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)
self.y_mask = generate_square_subsequent_mask(self.max_output_length)
self.transformer_decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
num_layers=tf_layers,
)
self.init_weights() # This is empirically important
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Autoregressively produce sequences of labels from input images.
Parameters
----------
x
(B, Ch, H, W) image, where Ch == 1 or Ch == 3
Returns
-------
output_tokens
(B, Sy) with elements in [0, C-1] where C is num_classes
"""
B = x.shape[0]
S = self.max_output_length
x = self.encode(x) # (Sx, B, E)
output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy)
output_tokens[:, 0] = self.start_token # Set start token
for Sy in range(1, S):
y = output_tokens[:, :Sy] # (B, Sy)
output = self.decode(x, y) # (Sy, B, C)
output = torch.argmax(output, dim=-1) # (Sy, B)
output_tokens[:, Sy] = output[-1] # Set the last output token
# Early stopping of prediction loop to speed up prediction
if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
break
# Set all tokens after end or padding token to be padding
for Sy in range(1, S):
ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
output_tokens[ind, Sy] = self.padding_token
return output_tokens # (B, Sy)
def init_weights(self):
initrange = 0.1
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_()
self.fc.weight.data.uniform_(-initrange, initrange)
nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
if self.encoder_projection.bias is not None:
_fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data)
bound = 1 / math.sqrt(fan_out)
nn.init.normal_(self.encoder_projection.bias, -bound, bound)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Encode each image tensor in a batch into a sequence of embeddings.
Parameters
----------
x
(B, Ch, H, W) image, where Ch == 1 or Ch == 3
Returns
-------
(Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps
"""
_B, C, _H, _W = x.shape
if C == 1:
x = x.repeat(1, 3, 1, 1)
x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs
x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs
# x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning
x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32
x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo)
x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo
return x
def decode(self, x, y):
"""Decode a batch of encoded images x with guiding sequences y.
During autoregressive inference, the guiding sequence will be previous predictions.
During training, the guiding sequence will be the ground truth.
Parameters
----------
x
(Sx, B, E) images encoded as sequences of embeddings
y
(B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes
Returns
-------
torch.Tensor
(Sy, B, C) batch of logit sequences
"""
y_padding_mask = y == self.padding_token
y = y.permute(1, 0) # (Sy, B)
y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E)
y = self.dec_pos_encoder(y) # (Sy, B, E)
Sy = y.shape[0]
y_mask = self.y_mask[:Sy, :Sy].type_as(x)
output = self.transformer_decoder(
tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
) # (Sy, B, E)
output = self.fc(output) # (Sy, B, C)
return output
@staticmethod
def add_to_argparse(parser):
    """Register the Transformer hyperparameter flags on ``parser`` and hand it back."""
    flag_specs = [
        ("--tf_dim", int, TF_DIM),
        # NOTE(review): --tf_fc_dim defaults to TF_DIM, not a dedicated feed-forward constant; confirm intended
        ("--tf_fc_dim", int, TF_DIM),
        ("--tf_dropout", float, TF_DROPOUT),
        ("--tf_layers", int, TF_LAYERS),
        ("--tf_nhead", int, TF_NHEAD),
    ]
    for flag, flag_type, default in flag_specs:
        parser.add_argument(flag, type=flag_type, default=default)
    return parser
================================================
FILE: lab05/text_recognizer/models/transformer_util.py
================================================
"""Position Encoding and other utilities for Transformers."""
import math
import torch
from torch import Tensor
import torch.nn as nn
class PositionalEncodingImage(nn.Module):
    """
    Adds fixed 2-D sinusoidal positional encodings to the feature map produced by the encoder.

    The first half of the channels encodes the row index and the second half the column index.
    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        table = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w)  # (d_model, max_h, max_w)
        # the table is deterministic, so by default it is left out of the state_dict and rebuilt on construction
        self.register_buffer("pe", table, persistent=persistent)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
        """Build a (d_model, max_h, max_w) table: row encodings stacked on top of column encodings."""
        half = d_model // 2
        rows = PositionalEncoding.make_pe(d_model=half, max_len=max_h)  # (max_h, 1, d_model // 2)
        row_grid = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (d_model // 2, max_h, max_w)
        cols = PositionalEncoding.make_pe(d_model=half, max_len=max_w)  # (max_w, 1, d_model // 2)
        col_grid = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (d_model // 2, max_h, max_w)
        return torch.cat([row_grid, col_grid], dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the positional table, cropped to x's spatial size, to a (B, d_model, H, W) tensor."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        return x + self.pe[:, :height, :width]  # type: ignore
class PositionalEncoding(torch.nn.Module):
    """Classic Attention-is-all-you-need positional encoding.

    Adds a fixed sinusoidal table to a (S, B, d_model) sequence and then applies dropout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        pe = self.make_pe(d_model=d_model, max_len=max_len)  # (max_len, 1, d_model)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> torch.Tensor:
        """Build the (max_len, 1, d_model) sinusoidal table.

        Even channels hold sin(pos * w_i) and odd channels cos(pos * w_i),
        with w_i = 10000 ** (-2i / d_model).

        Odd ``d_model`` is supported: the last channel is then a sine with no cosine partner.
        This matters for PositionalEncodingImage, which builds tables of depth d_model // 2.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # one frequency per even channel: ceil(d_model / 2) of them
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # there are only d_model // 2 odd channels; for odd d_model that is one fewer than
        # the number of frequencies, so the last frequency is dropped for the cosine half
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first S rows of the table to x (shape (S, B, d_model)) and apply dropout."""
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Return a (size, size) float32 causal attention mask.

    Entries on and below the diagonal are 0.0 (attention allowed); entries strictly
    above the diagonal are -inf (attention to later positions blocked).
    """
    blocked = torch.full((size, size), float("-inf"))
    # keep -inf only strictly above the diagonal; triu zeroes everything else
    return torch.triu(blocked, diagonal=1)
================================================
FILE: lab05/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """Preprocessing stem for models that operate on images.

    Inputs are expected to be PIL images, which is what torchvision Datasets usually provide.
    Processing runs in three stages:
    ``pil_transforms`` (PIL image in, PIL image out), a PIL-to-tensor conversion,
    and ``torch_transforms`` (tensor in, tensor out). The last stage runs without gradient tracking.
    Both transform stages are identities by default.

    ``torch_transforms`` is held in a torch.nn.Sequential, so it can be torchscripted
    as long as the Modules it contains can be.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()

    def __call__(self, img):
        tensor = self.pil_to_tensor(self.pil_transforms(img))
        with torch.no_grad():
            return self.torch_transforms(tensor)
class MNISTStem(ImageStem):
    """ImageStem variant for MNIST digits that normalizes tensors after conversion."""

    def __init__(self):
        super().__init__()
        # presumably the MNIST pixel mean and standard deviation (single channel) — confirm source
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.torch_transforms = torch.nn.Sequential(normalize)
================================================
FILE: lab05/text_recognizer/stems/line.py
================================================
import random
from PIL import Image
from torchvision import transforms
import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.stems.image import ImageStem
class LineStem(ImageStem):
    """Stem for images containing a single line of text, with optional augmentation.

    When ``augment`` is True, images get a color jitter followed by a random affine warp.
    Either step can be configured through its ``*_kwargs`` dict; None selects the defaults below.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()
        if not augment:
            # keep the identity PIL transforms inherited from ImageStem
            return

        jitter_kwargs = {"brightness": (0.5, 1)} if color_jitter_kwargs is None else color_jitter_kwargs
        affine_kwargs = random_affine_kwargs
        if affine_kwargs is None:
            affine_kwargs = {
                "degrees": 3,
                "translate": (0, 0.05),
                "scale": (0.4, 1.1),
                "shear": (-40, 50),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        self.pil_transforms = transforms.Compose(
            [transforms.ColorJitter(**jitter_kwargs), transforms.RandomAffine(**affine_kwargs)]
        )
class IAMLineStem(ImageStem):
    """Stem for IAMLines images.

    Each line crop is resized and pasted onto a fixed-size canvas. With ``augment``, the crop is
    also randomly stretched, and the canvas then goes through a color jitter and a random affine warp.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()

        def embed_crop(crop, augment=augment):
            """Scale a grayscale ("L") line crop to the canvas height and paste it onto a blank canvas."""
            canvas = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))

            width, height = crop.size
            target_height = metadata.IMAGE_HEIGHT
            # preserve the crop's aspect ratio at the canvas height
            target_width = int(target_height * (width / height))
            if augment:
                # random horizontal stretch of up to +/-10%
                target_width = int(target_width * random.uniform(0.9, 1.1))
            # the crop must never be wider than the canvas
            target_width = min(target_width, metadata.IMAGE_WIDTH)
            resized = crop.resize((target_width, target_height), resample=Image.BILINEAR)

            # indent by up to one character width, as long as the crop still fits
            left = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - target_width)
            top = metadata.IMAGE_HEIGHT - target_height
            canvas.paste(resized, (left, top))
            return canvas

        jitter_kwargs = {"brightness": (0.8, 1.6)} if color_jitter_kwargs is None else color_jitter_kwargs
        affine_kwargs = random_affine_kwargs
        if affine_kwargs is None:
            affine_kwargs = {
                "degrees": 1,
                "shear": (-30, 20),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        steps = [transforms.Lambda(embed_crop)]
        if augment:
            steps.extend([transforms.ColorJitter(**jitter_kwargs), transforms.RandomAffine(**affine_kwargs)])
        self.pil_transforms = transforms.Compose(steps)
================================================
FILE: lab05/text_recognizer/stems/paragraph.py
================================================
"""IAMParagraphs Stem class."""
import torchvision.transforms as transforms
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.stems.image import ImageStem
IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE
MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
class ParagraphStem(ImageStem):
    """Stem for images that contain a paragraph of text.

    Without augmentation, images are center-cropped to IMAGE_SHAPE.
    With augmentation, they pass through a color jitter, a padded random crop to IMAGE_SHAPE,
    a random affine warp, a random perspective warp, a Gaussian blur and a random sharpness adjustment.
    Any ``*_kwargs`` argument left as None falls back to the defaults below.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()

        if not augment:
            self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)])
            return

        jitter = color_jitter_kwargs if color_jitter_kwargs is not None else {"brightness": 0.4, "contrast": 0.4}
        affine = (
            random_affine_kwargs
            if random_affine_kwargs is not None
            else {
                "degrees": 3,
                "shear": 6,
                "scale": (0.95, 1),
                "interpolation": transforms.InterpolationMode.BILINEAR,
            }
        )
        perspective = (
            random_perspective_kwargs
            if random_perspective_kwargs is not None
            else {
                "distortion_scale": 0.2,
                "p": 0.5,
                "interpolation": transforms.InterpolationMode.BILINEAR,
            }
        )
        blur = gaussian_blur_kwargs if gaussian_blur_kwargs is not None else {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
        sharpness = sharpness_kwargs if sharpness_kwargs is not None else {"sharpness_factor": 2, "p": 0.5}

        # IMAGE_SHAPE is (576, 640)
        self.pil_transforms = transforms.Compose(
            [
                transforms.ColorJitter(**jitter),
                # pad with black where the image is smaller than IMAGE_SHAPE, then crop at a random offset
                transforms.RandomCrop(
                    size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant"
                ),
                transforms.RandomAffine(**affine),
                transforms.RandomPerspective(**perspective),
                transforms.GaussianBlur(**blur),
                transforms.RandomAdjustSharpness(**sharpness),
            ]
        )
================================================
FILE: lab05/text_recognizer/tests/test_callback_utils.py
================================================
"""Tests for the text_recognizer.callbacks.util module."""
import random
import string
import tempfile
import pytorch_lightning as pl
from text_recognizer.callbacks.util import check_and_warn
def test_check_and_warn_simple():
    """check_and_warn flags a missing attribute but accepts one every Python object has."""

    class Foo:
        pass  # deliberately bare: no custom attributes

    missing_name = "".join(random.choices(string.ascii_lowercase, k=10))
    assert check_and_warn(Foo(), missing_name, "random feature")
    assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects")
def test_check_and_warn_tblogger():
    """Test that we return a truthy value when trying to log tables with TensorBoard.

    We added check_and_warn in order to prevent a crash if this happens.
    """
    # Enter the TemporaryDirectory to get its path string: passing the object itself hands the
    # logger a non-path save_dir and leaves directory cleanup to garbage collection.
    with tempfile.TemporaryDirectory() as save_dir:
        tblogger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
        assert check_and_warn(tblogger, "log_table", "tables")
def test_check_and_warn_wandblogger():
    """check_and_warn must return a falsy value for a W&B logger, which can log tables.

    check_and_warn exists to prevent a crash, but it must not block the feature on the happy path.
    """
    logger = pl.loggers.WandbLogger(anonymous=True)
    warned = check_and_warn(logger, "log_table", "tables")
    assert not warned
================================================
FILE: lab05/text_recognizer/tests/test_iam.py
================================================
"""Test for data.iam module."""
from text_recognizer.data.iam import IAM
def test_iam_parsed_lines():
    """Every IAM document must have exactly as many line labels as line image crop regions."""
    iam = IAM()
    iam.prepare_data()
    strings_by_id = iam.line_strings_by_id
    regions_by_id = iam.line_regions_by_id
    for doc_id in iam.all_ids:
        num_labels = len(strings_by_id[doc_id])
        num_regions = len(regions_by_id[doc_id])
        assert num_labels == num_regions
def test_iam_data_splits():
    """The train, validation, and test splits must not share any identifiers."""
    iam = IAM()
    iam.prepare_data()
    train, validation, test = set(iam.train_ids), set(iam.validation_ids), set(iam.test_ids)
    for first, second in ((train, validation), (train, test), (validation, test)):
        assert first.isdisjoint(second)
================================================
FILE: lab05/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """One-hot encode integer class labels as uint8.

    Each label in ``y`` becomes the matching row of a ``num_classes`` identity matrix,
    so the result has shape ``np.shape(y) + (num_classes,)``.
    """
    identity = np.eye(num_classes, dtype="uint8")
    return identity[y]
def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Open an image from a local path or any URI smart_open can read, and load it with PIL.

    Parameters
    ----------
    image_uri
        Location of the image, opened in binary mode via ``smart_open.open``.
    grayscale
        If True, convert the image to single-channel "L" mode.

    Returns
    -------
    Image.Image
        The image as returned by ``read_image_pil_file``.
    """
    # annotation is Image.Image (the class); bare ``Image`` is the PIL.Image module
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Decode an open binary file object into a PIL image.

    Parameters
    ----------
    image_file
        File-like object opened in binary mode.
    grayscale
        If True, convert to single-channel "L" mode; otherwise keep the image's own mode.

    Returns
    -------
    Image.Image
        A converted copy of the decoded image.
    """
    # annotation is Image.Image (the class); bare ``Image`` is the PIL.Image module
    with Image.open(image_file) as image:
        if grayscale:
            image = image.convert(mode="L")
        else:
            # convert() returns a new, loaded copy even for the same mode, so the result stays
            # usable after the ``with`` block closes the image returned by Image.open
            image = image.convert(mode=image.mode)
    return image
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Change into ``working_dir`` for the duration of a ``with`` block.

    The previous working directory is restored on exit, even if the block raises.
    """
    previous_dir = Path.cwd()
    os.chdir(working_dir)
    try:
        yield
    finally:
        os.chdir(previous_dir)
def compute_sha256(filename: Union[Path, str]):
    """Return the hex SHA-256 checksum of a file's contents.

    The file is hashed in 1 MiB chunks, so memory use stays constant even for large
    downloaded archives instead of growing with the file size.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        # iter(callable, sentinel) keeps reading until read() returns b"" at EOF
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """tqdm progress bar with an ``update_to`` hook usable as a ``urlretrieve`` reporthook.

    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """Move the bar to an absolute position.

        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # tqdm.update expects an increment, so pass the delta from the current count (self.n)
        self.update(transferred - self.n)
def download_url(url, filename):
    """Fetch ``url`` into the local path ``filename`` while showing a byte-based progress bar."""
    progress = TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1)
    with progress:
        urlretrieve(url, filename, reporthook=progress.update_to, data=None)  # noqa: S310
================================================
FILE: lab05/training/__init__.py
================================================
================================================
FILE: lab05/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import callbacks as cb
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# In order to ensure reproducible experiments, we must set random seeds.
# These calls run at import time, so importing this module also reseeds NumPy's and torch's global RNGs.
# NOTE(review): Python's built-in `random` module (used e.g. by the IAMLineStem augmentation) is not seeded here.
np.random.seed(42)
torch.manual_seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments.

    Which data- and model-specific arguments exist depends on --data_class and --model_class.
    Those two are therefore read first with ``parse_known_args``, and the selected classes then
    register their own argument groups.

    Returns
    -------
    argparse.ArgumentParser
        Parser with Trainer, basic, Data, Model, and LitModel arguments, plus --help.
    """
    # add_help=False: --help is added only at the end, so the first parse_known_args pass below
    # cannot exit before the data/model argument groups have been registered
    parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    # NOTE(review): relies on the Trainer args living in action group index 1; confirm for the pinned Lightning version
    trainer_parser._action_groups[1].title = "Trainer Args"
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    parser.set_defaults(max_epochs=1)

    # Basic arguments
    parser.add_argument(
        "--wandb",
        action="store_true",
        default=False,
        help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.",
    )
    parser.add_argument(
        "--data_class",
        type=str,
        default="MNIST",
        help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--model_class",
        type=str,
        default="MLP",
        help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path."
    )
    parser.add_argument(
        "--stop_early",
        type=int,
        default=0,
        help="If non-zero, applies early stopping, with the provided value as the 'patience' argument."
        + " Default is 0.",
    )

    # Get the data and model classes, so that we can add their specific arguments
    # (unknown arguments, e.g. the data/model-specific ones, are ignored in this first pass)
    temp_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}")

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)

    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)

    # LitModel args always come from BaseLitModel, even when main() later picks TransformerLitModel
    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_models.BaseLitModel.add_to_argparse(lit_model_group)

    # --help is registered last so that the help text lists every group added above
    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create ``experiment_dir`` and any missing parents; only the rank-zero process runs this."""
    target = Path(experiment_dir)
    target.mkdir(exist_ok=True, parents=True)
def main():
    """
    Run an experiment.

    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```

    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```

    The available command line args differ depending on some of the arguments, including --model_class and --data_class.

    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```

    After fitting, the test set is evaluated using the best saved checkpoint if one exists,
    and the in-memory LitModel otherwise.
    """
    parser = _setup_parser()
    args = parser.parse_args()
    data, model = setup_data_and_model_from_args(args)

    # the "transformer" loss uses its own LitModel wrapper
    lit_model_class = lit_models.BaseLitModel
    if args.loss == "transformer":
        lit_model_class = lit_models.TransformerLitModel

    if args.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)
    else:
        lit_model = lit_model_class(args=args, model=model)

    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    # TensorBoard is the default logger; it is replaced below when --wandb is passed
    logger = pl.loggers.TensorBoardLogger(log_dir)
    experiment_dir = logger.log_dir

    # checkpoints are ranked by validation/cer for transformer models, by validation/loss otherwise
    goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss"
    filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
    if goldstar_metric == "validation/cer":
        filename_format += "-validation.cer={validation/cer:.3f}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename=filename_format,
        monitor=goldstar_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    summary_callback = pl.callbacks.ModelSummary(max_depth=2)

    callbacks = [summary_callback, checkpoint_callback]
    if args.wandb:
        logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train")
        logger.watch(model, log_freq=max(100, args.log_every_n_steps))
        logger.log_hyperparams(vars(args))
        # NOTE(review): checkpoint_callback was already built with the TensorBoard log_dir as dirpath,
        # so this reassignment only moves the profiler output below — confirm that is intended
        experiment_dir = logger.experiment.dir
        callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()]
    if args.stop_early:
        # NOTE(review): early stopping always monitors validation/loss, even when checkpoints are ranked by validation/cer
        early_stopping_callback = pl.callbacks.EarlyStopping(
            monitor="validation/loss", mode="min", patience=args.stop_early
        )
        callbacks.append(early_stopping_callback)

    if args.wandb and args.loss in ("transformer",):
        callbacks.append(cb.ImageToTextLogger())

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)
    if args.profile:
        # see torch.profiler.schedule for the meaning of wait/warmup/active/repeat
        sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0)
        profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir)
        profiler.STEP_FUNCTIONS = {"training_step"}  # only profile training
    else:
        profiler = pl.profiler.PassThroughProfiler()

    # the profiler is attached after construction rather than passed to from_argparse_args
    trainer.profiler = profiler

    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate

    trainer.fit(lit_model, datamodule=data)

    trainer.profiler = pl.profiler.PassThroughProfiler()  # turn profiling off during testing

    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        rank_zero_info(f"Best model saved at: {best_model_path}")
        if args.wandb:
            rank_zero_info("Best model also uploaded to W&B ")
        trainer.test(datamodule=data, ckpt_path=best_model_path)
    else:
        # no checkpoint was saved, so evaluate the in-memory model instead
        trainer.test(lit_model, datamodule=data)


if __name__ == "__main__":
    main()
================================================
FILE: lab05/training/tests/test_memorize_iam.sh
================================================
#!/bin/bash
# Memorization (overfitting) test: train ResnetTransformer on a single IAMParagraphs batch
# and check that the final train/loss recorded by W&B drops below a criterion.
#
# Usage (run from the lab root; the paths below are relative to it):
#   training/tests/test_memorize_iam.sh [MAX_EPOCHS] [CRITERION]
#     MAX_EPOCHS  maximum number of training epochs (default: 100)
#     CRITERION   train/loss value the run must end below (default: 1.0)
#
# Exit status: 0 if the criterion is met, 1 if training fails or the loss is too high.
set -uo pipefail
# errexit is deliberately off: a failing step sets FAILURE instead of aborting the script
set +e

# tests whether we can achieve a criterion loss
# on a single batch within a certain number of epochs

FAILURE=false

# constants and CLI args set by aiming for <5 min test on commodity GPU,
# including data download step
MAX_EPOCHS="${1:-100}"  # syntax for basic optional arguments in bash
CRITERION="${2:-1.0}"

# train on GPU if it's available
GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))')

# --overfit_batches 1 trains on a single batch. Testing, sanity validation, data augmentation
# and transformer dropout are all disabled so the model can memorize it.
# --wandb is required: the loss check below reads the W&B run summary.
python ./training/run_experiment.py \
  --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \
  --augment_data false --tf_dropout 0.0 \
  --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \
  --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true

# read the last logged train/loss from the most recent W&B run and compare it to the criterion
python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true

if [ "$FAILURE" = true ]; then
  echo "Memorization test failed at loss criterion $CRITERION"
  exit 1
fi
echo "Memorization test passed at loss criterion $CRITERION"
exit 0
================================================
FILE: lab05/training/tests/test_run_experiment.sh
================================================
#!/bin/bash
# Smoke tests for training/run_experiment.py (run from the lab root; paths are relative to it).
# 1. A full one-epoch train/validate/test loop of a tiny CNN on FakeImageData.
# 2. A --fast_dev_run of a tiny ResnetTransformer on real IAMParagraphs data.
# Exit status: 0 if both runs succeed, 1 if either fails.
set -uo pipefail
# errexit is deliberately off so that both runs execute even if the first one fails
set +e

FAILURE=false

echo "running full loop test with CNN on fake data"
python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1 || FAILURE=true

echo "running fast_dev_run test of real model class on real data"
# tiny transformer dimensions keep this run fast; only the wiring is exercised, not the model quality
python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \
  --fast_dev_run --num_sanity_val_steps 0 \
  --num_workers 1 || FAILURE=true

if [ "$FAILURE" = true ]; then
  echo "Test for run_experiment.py failed"
  exit 1
fi
echo "Tests for run_experiment.py passed"
exit 0
================================================
FILE: lab05/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
DATA_CLASS_MODULE = "text_recognizer.data"
MODEL_CLASS_MODULE = "text_recognizer.models"
def import_class(module_and_class_name: str) -> type:
    """Resolve a dotted path such as 'text_recognizer.models.MLP' to the object it names.

    Everything before the last dot is imported as a module, and the final component is
    looked up on it as an attribute.
    """
    module_path, attr_name = module_and_class_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Instantiate the data module and model named by ``args.data_class`` and ``args.model_class``.

    Both classes are resolved first. The data object is then built, and its ``config()``
    is passed to the model constructor.

    Returns
    -------
    tuple
        ``(data, model)``
    """
    data_cls, model_cls = (
        import_class(f"{DATA_CLASS_MODULE}.{args.data_class}"),
        import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}"),
    )
    data = data_cls(args)
    return data, model_cls(data_config=data.config(), args=args)
================================================
FILE: lab06/.flake8
================================================
[flake8]
select = ANN,B,B9,BLK,C,D,E,F,I,S,W
# only check selected error codes
max-complexity = 12
# C9 - flake8 McCabe Complexity checker -- threshold
max-line-length = 120
# E501 - flake8 -- line length too long, actually handled by black
extend-ignore =
# E W - flake8 PEP style check
E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks
# S - flake8-bandit safety check
S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password
# ANN - flake8-annotations type annotation check
ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some
# D1 - flake8-docstrings docstring style check
D100,D102,D103,D104,D105, # missing docstrings
# D2 D4 - flake8-docstrings docstring style check
D200,D205,D400,D401, # whitespace issues and first line content
# DAR - flake8-darglint docstring correctness check
DAR103, # mismatched or missing type in docstring
application-import-names = app_gradio,text_recognizer,tests,training
# flake8-import-order: which names are first party?
import-order-style = google
# flake8-import-order: which import order style guide do we use?
docstring-convention = numpy
# flake8-docstrings: which docstring style guide do we use?
strictness = short
# darglint: how "strict" are we with docstring completeness?
docstring-style = numpy
# darglint: which docstring style guide do we use?
suppress-none-returning = true
# flake8-annotations: do we allow un-annotated Nones in returns?
mypy-init-return = true
# flake8-annotations: do we allow init to have no return annotation?
per-file-ignores =
# list of case-by-case ignores, see files for details
*/__init__.py:F401,I
*/data/*.py:DAR
data/*.py:F,I
*text_recognizer/util.py:DAR101,F401
*training/run_experiment.py:I202
*app_gradio/app.py:I202
================================================
FILE: lab06/.github/workflows/pre-commit.yml
================================================
name: pre-commit
on:
pull_request:
push:
# allows this Action to be triggered manually
workflow_dispatch:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.10'
- uses: pre-commit/action@v3.0.0
================================================
FILE: lab06/.pre-commit-config.yaml
================================================
repos:
# a set of useful Python-based pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
# list of definitions and supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace # removes any whitespace at the ends of lines
- id: check-toml # check toml syntax by loading all toml files
- id: check-yaml # check yaml syntax by loading all yaml files
- id: check-json # check-json syntax by loading all json files
- id: check-merge-conflict # check for files with merge conflict strings
args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge
- id: check-added-large-files # check that no "large" files have been added
args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server
- id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.)
- id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.)
# black python autoformatting
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
# additional configuration of black in pyproject.toml
# flake8 python linter with all the fixins
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08)
additional_dependencies: [
flake8-bandit, flake8-bugbear, flake8-docstrings,
flake8-import-order, darglint, mypy, pycodestyle, pydocstyle]
args: ["--config", ".flake8"]
# additional configuration of flake8 and extensions in .flake8
# shellcheck-py for linting shell files
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.8.0.4
hooks:
- id: shellcheck
================================================
FILE: lab06/notebooks/lab01_pytorch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" `.\n",
"\n",
"A model that always predicts ` ` can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab06/notebooks/lab03_transformers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" \", \"\")\n",
"\n",
"idx = random.randint(0, len(xs))\n",
"\n",
"print(show(ys[idx]))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4dT3UCNzTsoc"
},
"source": [
"The `ResnetTransformer` model can run on this data\n",
"if passed the `.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXL-vIGRr86D"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MMxa-oWyT01E"
},
"source": [
"Our models are now big enough\n",
"that we want to make use of GPU acceleration\n",
"as much as we can,\n",
"even when working on single inputs,\n",
"so let's cast to the GPU if we have one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-YyUM8LgvW0w"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"rnt.to(device); xs = xs.to(device); ys = ys.to(device);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-E3UdD4zUJi"
},
"source": [
"First, let's just pass it through the ResNet encoder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-LUUtlvaxrvg"
},
"outputs": [],
"source": [
"resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n",
" # resnet is designed for RGB images, so we replicate the input across channels 3 times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eimgJ5dnywjg"
},
"outputs": [],
"source": [
"resnet_idx = random.randint(0, len(resnet_embedding)) # re-execute to view a different channel\n",
"plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n",
"plt.axis(\"off\"); plt.colorbar(fraction=0.05);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These embeddings, though generated by random, untrained weights,\n",
"are not entirely useless.\n",
"\n",
"Before neural networks could be effectively\n",
"trained end to end,\n",
"they were often used with frozen random weights\n",
"eveywhere except the final layer\n",
"(see e.g.\n",
"[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n",
"[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n",
"these methods were still competitive, and\n",
"[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n",
"provide a\n",
"[theoretical basis](https://arxiv.org/abs/2011.14522)\n",
"for understanding their performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ye6pW0ETzw2A"
},
"source": [
"The final result, though, is repetitive gibberish --\n",
"at the bare minimum, we need to train the unembedding/readout layer\n",
"in order to get reasonable text."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our architecture includes randomization with dropout,\n",
"so repeated runs of the cell below will generate different outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu3Pa7gLsFMo"
},
"outputs": [],
"source": [
"preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gvCXUbskv6XM"
},
"outputs": [],
"source": [
"print(show(preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without teacher forcing, runtime is also variable from iteration to iteration --\n",
"the model stops when it generates an \"end sequence\" or padding token,\n",
"which is not deterministic thanks to the dropout layers.\n",
"For similar reasons, runtime is variable across inputs.\n",
"\n",
"The variable runtime of autoregressive generation\n",
"is also not great for scaling.\n",
"In a distributed setting, as required for large scale,\n",
"forward passes need to be synced across devices,\n",
"and if one device is generating a batch of much longer sequences,\n",
"it will cause all the others to idle while they wait on it to finish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t76MSVRXV0V7"
},
"source": [
"Let's turn our model into a `TransformerLitModel`\n",
"so we can run with teacher forcing.\n",
"\n",
"> You may be wondering:\n",
" why isn't teacher forcing part of the PyTorch module?\n",
" In general, the `LightningModule`\n",
" should encapsulate things that are needed in training, validation, and testing\n",
" but not during inference.\n",
" The teacher forcing trick fits this paradigm,\n",
" even though it's so critical to what makes Transformers powerful. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qrHRKHowdDi"
},
"outputs": [],
"source": [
"import text_recognizer.lit_models\n",
"\n",
"lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlNaFqR50Oid"
},
"source": [
"Now we can use `.teacher_forward` if we also provide the target `ys`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lpZdqXS5wn0F"
},
"outputs": [],
"source": [
"forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Zx9SmsN0QLT"
},
"source": [
"This may not run faster than the `rnt.forward`,\n",
"since generations are always the maximum possible length,\n",
"but runtimes and output lengths are deterministic and constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tu-XNYpi0Qvi"
},
"source": [
"Forcing doesn't necessarily make our predictions better.\n",
"They remain highly repetitive gibberish."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JcEgify9w0sv"
},
"outputs": [],
"source": [
"forcing_preds = torch.argmax(forcing_outs, dim=0)\n",
"\n",
"print(show(forcing_preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xn6GGNzc9a3o"
},
"source": [
"## Training the `ResNetTransformer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uvZYsuSyWUXe"
},
"source": [
"We're finally ready to train this model on full paragraphs of handwritten text!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cJwC7b720Sd"
},
"source": [
"This is a more serious model --\n",
"it's the one we use in the\n",
"[deployed TextRecognizer application](http://fsdl.me/app).\n",
"It's much larger than the models we've seen this far,\n",
"so it can easily outstrip available compute resources,\n",
"in particular GPU memory.\n",
"\n",
"To help, we use\n",
"[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n",
"which shrinks the size of most of our floats by half,\n",
"which reduces memory consumption and can speed up computation.\n",
"\n",
"If your GPU has less than 8GB of available RAM,\n",
"you'll see a \"CUDA out of memory\" `RuntimeError`,\n",
"which is something of a\n",
"[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n",
"In this case, you can resolve it by reducing the `--batch_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w1mXlhfy04Nm"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gpus = int(torch.cuda.is_available())\n",
"\n",
"if gpus:\n",
" !nvidia-smi\n",
"else:\n",
" print(\"watch out! working with this model on a typical CPU is not feasible\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "os1vW1rPZ1dy"
},
"source": [
"Even with an okay GPU, like a\n",
"[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n",
"a single epoch of training can take over 10 minutes to run.\n",
"We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n",
"but you can remove those flags to see what full training looks like."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnF6dWFn4JlZ"
},
"source": [
"It can take a long time (overnight)\n",
"to train this model to decent performance on a single GPU,\n",
"so we'll focus on other pieces for the exercises.\n",
"\n",
"> At the time of writing in mid-2022, the cheapest readily available option\n",
"for training this model to decent performance on this dataset with this codebase\n",
"comes out around $10, using\n",
"[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n",
"See, for example,\n",
"[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n",
"and associated experiment.\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HufjdUZN0t4l",
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"# above %%magic times the cell, useful as a poor man's profiler\n",
"\n",
"%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n",
" --gpus={gpus} --batch_size 16 --precision 16 \\\n",
" --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L6fQ93ju3Iku"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udb1Ekjx3L63"
},
"source": [
"### 🌟 Try out gradient accumulation and other \"training tricks\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kpqViB4p3Wfb"
},
"source": [
"Larger batches are helpful not only for increasing parallelization\n",
"and amortizing fixed costs\n",
"but also for getting more reliable gradients.\n",
"Larger batches give gradients with less noise\n",
"and to a point, less gradient noise means faster convergence.\n",
"\n",
"But larger batches result in larger tensors,\n",
"which take up more GPU memory,\n",
"a resource that is tightly constrained\n",
"and device-dependent.\n",
"\n",
"Does that mean we are limited in the quality of our gradients\n",
"due to our machine size?\n",
"\n",
"Not entirely:\n",
"look up the `--accumulate_grad_batches`\n",
"argument to the `pl.Trainer`.\n",
"You should be able to understand why\n",
"it makes it possible to compute the same gradients\n",
"you would find for a batch of size `k * N`\n",
"on a machine that can only run batches up to size `N`.\n",
"\n",
"Accumulating gradients across batches is among the\n",
"[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n",
"Try some of them out!\n",
"Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2vtkmX830y3"
},
"source": [
"### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n",
"\n",
"While training this model to actually fit the whole dataset is infeasible\n",
"as a short exercise on commodity hardware,\n",
"it's practical to train this model to memorize a batch of 16 examples.\n",
"\n",
"Passing `--overfit_batches 1` flag limits the number of training batches to 1\n",
"and turns off\n",
"[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n",
"so that in each epoch, the model just sees the same single batch of data over and over again.\n",
"\n",
"At first, try training the model to a loss of `2.5` --\n",
"it should be doable in 100 epochs or less,\n",
"which is just a few minutes on a commodity GPU.\n",
"\n",
"Once you've got that working,\n",
"crank up the number of epochs by a factor of 10\n",
"and confirm that the loss continues to go down.\n",
"\n",
"Some tips:\n",
"\n",
"- Use `--limit_test_batches 0` to turn off testing.\n",
"We don't need it because we don't care about generalization\n",
"and it's relatively slow because it runs the model autoregressively.\n",
"\n",
"- Use `--help` and look through the model class args\n",
"to find the arguments used to reduce model size.\n",
"\n",
"- By default, there's lots of regularization to prevent overfitting.\n",
"Look through the args for the model class and data class\n",
"for regularization knobs to turn off or down."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab03_transformers.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab06/notebooks/lab04_experiments.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" ", *characters, *iam_characters]
if __name__ == "__main__":
load_and_print_info(EMNIST)
================================================
FILE: lab06/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab06/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
# File locations are re-exported from the EMNIST Lines metadata module.
PROCESSED_DATA_DIRNAME, ESSENTIALS_FILENAME = metadata.PROCESSED_DATA_DIRNAME, metadata.ESSENTIALS_FILENAME

# Line-generation defaults; max_length and the overlap bounds can be overridden from the CLI.
DEFAULT_MAX_LENGTH = 32
DEFAULT_MIN_OVERLAP = 0
DEFAULT_MAX_OVERLAP = 0.33

# Number of synthetic lines generated for each split.
NUM_TRAIN, NUM_VAL, NUM_TEST = 10_000, 2_000, 2_000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters."""

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)
        # Generation settings; each falls back to its module-level default when absent from args.
        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)

        self.mapping = metadata.MAPPING
        self.output_dims = (self.max_length, 1)
        # Widest possible line: max_length characters of CHAR_WIDTH pixels with no overlap.
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)

        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        """Register EMNIST Lines generation flags on top of the base data module flags."""
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """HDF5 cache path; every generation setting is part of the name, so settings never share a file."""
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate and cache all three splits unless a cache file for these settings already exists."""
        if self.data_filename.exists():
            return
        # Fixed NumPy seed so the random character-sample and overlap draws are reproducible.
        np.random.seed(42)
        self._generate_data("train")
        self._generate_data("val")
        self._generate_data("test")

    def setup(self, stage: str = None) -> None:
        """Load the splits needed for `stage` ("fit", "test", or None for all) from the HDF5 cache."""
        print("EMNISTLinesDataset loading data from HDF5...")
        if stage == "fit" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)

            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)

        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset.

        Works whether no split, some splits (e.g. after setup("fit") or setup("test") only),
        or all splits are loaded; missing splits are reported as None.
        """
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        # getattr guards against the split attributes not having been assigned yet.
        data_train, data_val, data_test = (
            getattr(self, name, None) for name in ("data_train", "data_val", "data_test")
        )
        if data_train is None and data_val is None and data_test is None:
            return basic

        sizes = ", ".join("None" if split is None else str(len(split)) for split in (data_train, data_val, data_test))
        data = f"Train/val/test sizes: {sizes}\n"
        if data_train is not None:
            # Batch statistics come from the training loader, which needs the training split.
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
            )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Synthesize `split` ("train", "val" or "test") and append its x/y datasets to the HDF5 cache."""
        print(f"EMNISTLinesDataset generating data for {split}...")

        from text_recognizer.data.sentence_generator import SentenceGenerator

        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract two because we will add start/end tokens

        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()

        # train and val lines both draw glyphs from EMNIST's train/val pool; only test uses EMNIST's test set
        if split == "train":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train
        elif split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test

        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        # "a" mode: each split is appended to the same file as its own pair of datasets.
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Group samples by the character their integer label maps to.

    Returns a defaultdict(list), so looking up a character with no samples yields [].
    """
    by_char = defaultdict(list)
    for img, lbl in zip(samples, labels):
        char = mapping[lbl]
        by_char[char].append(img)
    return by_char
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Choose one image per distinct character of `string`; return the images in string order.

    A character with no samples is drawn as an all-zero uint8 image of `char_shape`.
    Repeated characters reuse the image picked at their first occurrence.
    """
    blank = torch.zeros(char_shape, dtype=torch.uint8)
    chosen = {}
    # dict.fromkeys keeps first-appearance order, so the random draws happen in the same sequence.
    for ch in dict.fromkeys(string):
        candidates = samples_by_char[ch]
        if candidates:
            picked = candidates[np.random.choice(len(candidates))]
        else:
            picked = blank
        chosen[ch] = picked.reshape(*char_shape)
    return [chosen[ch] for ch in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render `string` as one image by pasting a sampled glyph per character left to right.

    Parameters
    ----------
    string
        Text to render; must be non-empty.
    samples_by_char
        Mapping from character to candidate glyph images.
    min_overlap, max_overlap
        Bounds of the uniformly drawn fraction by which neighbouring glyphs overlap.
    width
        Width in pixels of the output canvas.

    Returns
    -------
    torch.Tensor
        (H, width) float tensor; overlapping pixel intensities are summed and capped at 255.
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in a wide integer type: a uint8 canvas would wrap modulo 256 where glyphs
    # overlap, before the clamp below could cap the sum at 255.
    concatenated_image = torch.zeros((H, width), dtype=torch.int32)
    x = 0
    for image in sampled_images:
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image, dtype=torch.int32)
        x += next_overlap_width
    return torch.minimum(torch.Tensor([255]), concatenated_image)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate N synthetic line images together with the sentences they depict.

    Returns
    -------
    images
        (N, dims[1], dims[2]) float tensor; row n renders labels[n].
    labels
        List of the N generated sentence strings.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    labels = []
    for row in range(N):
        sentence = sentence_generator.generate()
        images[row] = construct_image_from_string(sentence, samples_by_char, min_overlap, max_overlap, dims[-1])
        labels.append(sentence)
    return images, labels
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" "]
self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
self.val_cer = CharacterErrorRate(self.ignore_tokens)
self.test_cer = CharacterErrorRate(self.ignore_tokens)
================================================
FILE: lab06/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence
import torch
import torchmetrics
class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored.

    Predictions and targets are integer token tensors. Every token listed in
    `ignore_tokens` is removed from each sequence before it reaches the parent metric.
    Extra positional and keyword arguments are forwarded to torchmetrics.CharErrorRate.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        super().__init__(*args, **kwargs)
        # set for O(1) membership tests while filtering every token of every sequence
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        """Strip ignored tokens from each row of preds/targets, then update the parent metric.

        NOTE(review): the parent receives lists of token-id lists rather than strings; this
        relies on torchmetrics treating each entry as a generic token sequence.
        """
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)
def test_character_error_rate():
    """Self-check: tokens 0 and 1 are ignored and the remaining tokens are scored."""
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    # Assuming torchmetrics' CER is total edits / total target tokens: every target here keeps
    # 4 tokens after filtering, so that ratio equals the mean of the per-row errors above.
    expected = sum([0, 0.75, 0.5]) / 3
    # Compare with a tolerance rather than ==: compute() returns a floating-point tensor,
    # so exact equality with a Python float depends on rounding.
    assert torch.isclose(metric.compute(), torch.tensor(expected)), metric.compute()
if __name__ == "__main__":
test_character_error_rate()
================================================
FILE: lab06/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence
import torch
from .base import BaseImageToTextLitModel
from .util import replace_after
class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.

    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """

    def __init__(self, model, args=None):
        super().__init__(model, args)
        # padded target positions contribute nothing to the loss
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)

    def forward(self, x):
        """Run the wrapped model's production (inference) forward pass."""
        return self.model(x)

    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.

        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.

        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def training_step(self, batch, batch_idx):
        """Teacher-forced step: feed y[:, :-1] and score the predictions against y[:, 1:]."""
        x, y = batch
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("train/loss", loss)

        outputs = {"loss": loss}

        if self.is_logged_batch():
            preds = self.get_preds(logits)
            pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
            outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs})

        return outputs

    def _evaluation_step(self, batch, batch_idx, prefix: str, cer_metric):
        """Shared body of validation_step/test_step: logs under `prefix` and updates `cer_metric`."""
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log(f"{prefix}/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)
        cer_metric(preds, y)
        self.log(f"{prefix}/cer", cer_metric, prog_bar=True, sync_dist=True)

        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def validation_step(self, batch, batch_idx):
        """Validation: teacher-forced loss plus autoregressive CER, accumulated in val_cer."""
        return self._evaluation_step(batch, batch_idx, "validation", self.val_cer)

    def test_step(self, batch, batch_idx):
        """Test: like validation, but accumulates in the dedicated test_cer metric."""
        # Using test_cer (not val_cer) keeps test results from sharing metric state with validation.
        return self._evaluation_step(batch, batch_idx, "test", self.test_cer)

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping.

        Tensors are converted to Python ints first, which avoids one tensor comparison per
        token when checking membership in ignore_tokens.
        """
        if isinstance(ks, torch.Tensor):
            ks = ks.tolist()
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping."""
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.

        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.

        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        return raw  # (B, Sy)
================================================
FILE: lab06/text_recognizer/lit_models/util.py
================================================
from typing import Union
import torch
def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Locate the first occurrence of `element` in `x` along `dim`.

    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9

    Parameters
    ----------
    x
        One- or two-dimensional Tensor to search.
    element
        Value to look for.
    dim
        Dimension that is searched (and collapsed) in the result.

    Returns
    -------
    torch.Tensor
        Index of the first match along `dim`, or `x.shape[dim]` where there is no match.
        Has one dimension fewer than `x`.

    Raises
    ------
    ValueError
        if x is not a 1 or 2 dimensional Tensor

    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    if x.dim() not in (1, 2):
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}")
    hits = x == element
    # a hit is the first one exactly when the running hit count has just reached 1
    only_first = hits & (hits.cumsum(dim) == 1)
    found, position = only_first.max(dim)
    not_found_value = x.shape[dim]
    return torch.where(found, position, not_found_value)
def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Replace all values in each row of 2d Tensor x after the first appearance of element with replace.

    Parameters
    ----------
    x
        Two-dimensional Tensor (shape denoted (B, S)) to replace values in.
    element
        Item to search for inside x.
    replace
        Item that replaces entries that appear after element.

    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.
        Rows that do not contain element are returned unchanged.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    first_appearances = first_appearance(x, element, dim=1)  # (B,)
    # Positions are int64 on x's device instead of x's dtype: casting them to a narrow dtype
    # such as uint8 would wrap past 255 and corrupt the comparison for long rows.
    indices = torch.arange(x.shape[-1], device=x.device)  # (S,)
    outs = torch.where(
        indices[None, :] <= first_appearances[:, None],  # if index is at or before first appearance
        x,  # return the value from x
        replace,  # otherwise, return the replacement value
    )
    return outs  # (B, S)
================================================
FILE: lab06/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path
import text_recognizer.metadata.shared as shared
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"
NUM_SPECIAL_TOKENS = 4

INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)

# The first NUM_SPECIAL_TOKENS entries are special tokens. "<S>", "<E>" and "<P>" serve as the
# start, end and padding tokens (cf. start_index / end_index / padding_index in the lit models);
# "<B>" is presumably a blank token. They must be distinct from every real character, in
# particular the space " ": duplicates silently collapse a character -> index inverse mapping.
MAPPING = [
    "<B>",
    "<S>",
    "<E>",
    "<P>",
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "q",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z",
    " ",
    "!",
    '"',
    "#",
    "&",
    "'",
    "(",
    ")",
    "*",
    "+",
    ",",
    "-",
    ".",
    "/",
    ":",
    ";",
    "?",
]
================================================
FILE: lab06/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"
CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None) # width variable, depends on maximum sequence length
MAPPING = emnist.MAPPING
================================================
FILE: lab06/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared
RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"
DOWNSAMPLE_FACTOR = 2 # if images were downsampled, the regions must also be
LINE_REGION_PADDING = 8 # add this many pixels around the exact coordinates
================================================
FILE: lab06/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines"
IMAGE_SCALE_FACTOR = 2
CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR
IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR # rounding up IAMLines empirical maximum width
DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
OUTPUT_DIMS = (89, 1)
MAPPING = emnist.MAPPING
================================================
FILE: lab06/text_recognizer/metadata/iam_paragraphs.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs"
NEW_LINE_TOKEN = "\n"
MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN]
# must match IMAGE_SCALE_FACTOR for IAMLines to be compatible with synthetic paragraphs
IMAGE_SCALE_FACTOR = 2
IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH)
MAX_LABEL_LENGTH = 682
DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1)
================================================
FILE: lab06/text_recognizer/metadata/iam_synthetic_paragraphs.py
================================================
import text_recognizer.metadata.iam_paragraphs as iam_paragraphs
import text_recognizer.metadata.shared as shared
NEW_LINE_TOKEN = iam_paragraphs.NEW_LINE_TOKEN

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_synthetic_paragraphs")

# Parameters of a typical training run, used only to size the synthetic dataset.
EXPECTED_BATCH_SIZE, EXPECTED_GPUS, EXPECTED_STEPS = 64, 8, 40
# set the dataset's length based on parameters during typical training
DATASET_LEN = EXPECTED_BATCH_SIZE * EXPECTED_GPUS * EXPECTED_STEPS
================================================
FILE: lab06/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
# MNIST downloads share the project-wide download directory.
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME

DIMS = (1, 28, 28)  # presumably (channels, height, width)
OUTPUT_DIMS = (1,)
MAPPING = list(range(10))  # class index i maps to the digit i

# Train / validation split sizes.
TRAIN_SIZE, VAL_SIZE = 55000, 5000
================================================
FILE: lab06/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# <repo>/data: parents[3] climbs metadata -> text_recognizer -> lab06 -> repo root.
DATA_DIRNAME = Path(__file__).resolve().parents[3].joinpath("data")
# Raw downloads get their own subdirectory of the data directory.
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME.joinpath("downloaded")
================================================
FILE: lab06/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
from .resnet_transformer import ResnetTransformer
from .line_cnn_transformer import LineCNNTransformer
================================================
FILE: lab06/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
CONV_DIM = 64
FC_DIM = 128
FC_DROPOUT = 0.25


class ConvBlock(nn.Module):
    """A 3x3 convolution (stride 1, padding 1) followed by a ReLU.

    The padding keeps the spatial size of the input unchanged.
    """

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve x and rectify the result.

        Parameters
        ----------
        x
            (B, C_in, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, C_out, H, W) tensor
        """
        return self.relu(self.conv(x))


class CNN(nn.Module):
    """Two-conv CNN that classifies a single character in a square image."""

    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_channels, input_height, input_width = data_config["input_dims"]
        assert (
            input_height == input_width
        ), f"input height and width should be equal, but was {input_height}, {input_width}"
        self.input_height, self.input_width = input_height, input_width

        num_classes = len(data_config["mapping"])
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)

        # Submodules are created in a fixed order; it determines parameter init order.
        self.conv1 = ConvBlock(input_channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)

        # The padded 3x3 convs keep H and W; only the 2x2 max-pool halves them.
        pooled_height, pooled_width = input_height // 2, input_width // 2
        self.fc_input_dim = int(pooled_height * pooled_width * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of square character images.

        Parameters
        ----------
        x
            (B, Ch, H, W) tensor; H and W must equal the input height and width from data_config.

        Returns
        -------
        torch.Tensor
            (B, Cl) tensor of logits
        """
        _B, _Ch, H, W = x.shape
        assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}"
        features = self.max_pool(self.conv2(self.conv1(x)))  # (B, CONV_DIM, H // 2, W // 2)
        features = torch.flatten(self.dropout(features), 1)  # (B, CONV_DIM * H // 2 * W // 2)
        hidden = F.relu(self.fc1(features))  # (B, FC_DIM)
        return self.fc2(hidden)  # (B, Cl)

    @staticmethod
    def add_to_argparse(parser):
        for flag, kind, default in (
            ("--conv_dim", int, CONV_DIM),
            ("--fc_dim", int, FC_DIM),
            ("--fc_dropout", float, FC_DROPOUT),
        ):
            parser.add_argument(flag, type=kind, default=default)
        return parser
================================================
FILE: lab06/text_recognizer/models/line_cnn.py
================================================
"""Basic building blocks for convolutional models over lines of text."""
import argparse
import math
from typing import Any, Dict, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
# Common type hints
Param2D = Union[int, Tuple[int, int]]

CONV_DIM = 32
FC_DIM = 512
FC_DROPOUT = 0.2
WINDOW_WIDTH = 16
WINDOW_STRIDE = 8


class ConvBlock(nn.Module):
    """
    Conv2d followed by a ReLU.

    With the default kernel_size=3, stride=1, padding=1 the spatial size of the input is left
    unchanged; other values change it exactly as nn.Conv2d would.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: Param2D = 3,
        stride: Param2D = 1,
        padding: Param2D = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the ConvBlock to x.

        Parameters
        ----------
        x
            (B, C_in, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, C_out, H_out, W_out) tensor; H_out == H and W_out == W with the default arguments
        """
        c = self.conv(x)
        r = self.relu(c)
        return r


class LineCNN(nn.Module):
    """
    Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.args = vars(args) if args is not None else {}
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        _C, H, _W = data_config["input_dims"]
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        # Input is (1, H, W). Each stride-2 block halves H and W (rounding up). The last block's
        # kernel covers H // 8 rows and WW // 8 columns with stride (H // 8, WS // 8), so it acts
        # like a window of width WW and stride WS over the original image. forward() assumes
        # this collapses the height to 1.
        self.convs = nn.Sequential(
            ConvBlock(1, conv_dim),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim, stride=2),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim * 2, stride=2),
            ConvBlock(conv_dim * 2, conv_dim * 2),
            ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
            ConvBlock(conv_dim * 4, conv_dim * 4),
            ConvBlock(
                conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0
            ),
        )
        self.fc1 = nn.Linear(fc_dim, fc_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(fc_dim, self.num_classes)

        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights in a better way than default.

        Conv/Linear weights get Kaiming-normal init (fan_out mode, ReLU gain). Their biases are
        drawn uniformly from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        See https://github.com/pytorch/pytorch/issues/18182
        """
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(tensor, a, b) samples from [a, b). The previous
                    # normal_(tensor, -bound, bound) read these as (mean, std), which gave
                    # biases centered at -bound.
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the LineCNN to a black-and-white input image.

        Parameters
        ----------
        x
            (B, 1, H, W) input image

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W, self.WW (window width) and self.WS (window stride)
            C is self.num_classes
        """
        x = self.convs(x)  # (B, FC_DIM, 1, Sx)
        x = x.squeeze(2).permute(0, 2, 1)  # (B, S, FC_DIM)
        x = F.relu(self.fc1(x))  # -> (B, S, FC_DIM)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, S, C)
        x = x.permute(0, 2, 1)  # -> (B, C, S)
        if self.limit_output_length:
            # S may exceed the label length; keep only the expected number of positions.
            x = x[:, :, : self.output_length]
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab06/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
IMAGE_SIZE = 28
WINDOW_WIDTH = IMAGE_SIZE
WINDOW_STRIDE = IMAGE_SIZE


class LineCNNSimple(nn.Module):
    """Slide a single-character CNN across a line image, one window at a time.

    Each window is WW pixels wide. Because the wrapped CNN expects square input, it is
    treated as WW x WW. Consecutive windows start WS pixels apart.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        # The per-window CNN sees (channels, WW, WW) crops; the rest of the config is reused.
        window_config = dict(data_config)
        window_config["input_dims"] = (data_config["input_dims"][0], self.WW, self.WW)
        self.cnn = CNN(data_config=window_config, args=args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CNN over every window of the line and stack the logits.

        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to IMAGE_SIZE

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S = floor((W - WW) / WS + 1) is the number of windows
            and C is self.num_classes
        """
        B, _C, H, W = x.shape
        assert H == IMAGE_SIZE  # Make sure we can use our CNN class

        num_windows = math.floor((W - self.WW) / self.WS + 1)
        # type_as moves the buffer to x's device and dtype
        activations = torch.zeros((B, self.num_classes, num_windows)).type_as(x)
        for s in range(num_windows):
            left = s * self.WS
            activations[:, :, s] = self.cnn(x[:, :, :, left : left + self.WW])  # window: (B, C, H, WW)

        if self.limit_output_length:
            # The window count need not match the label length; keep only as many as expected.
            activations = activations[:, :, : self.output_length]
        return activations

    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        window_args = (
            ("--window_width", WINDOW_WIDTH, "Width of the window that will slide over the input image."),
            ("--window_stride", WINDOW_STRIDE, "Stride of the window that will slide over the input image."),
        )
        for flag, default, help_text in window_args:
            parser.add_argument(flag, type=int, default=default, help=help_text)
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab06/text_recognizer/models/line_cnn_transformer.py
================================================
"""Model that combines a LineCNN with a Transformer model for text prediction."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .line_cnn import LineCNN
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding
TF_DIM = 256
TF_FC_DIM = 256
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4


class LineCNNTransformer(nn.Module):
    """Process the line through a CNN and process the resulting sequence with a Transformer decoder."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # decode() and forward() need start, end and padding token indices, so the mapping must
        # contain "<S>", "<E>" and "<P>". A KeyError here means it does not. Previously only a
        # start token (" ") was set, and any decode/forward call raised AttributeError.
        self.start_token = inverse_mapping["<S>"]
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # Instantiate LineCNN with "num_classes" set to self.dim, so it emits one dim-sized
        # feature vector per window instead of class logits.
        data_config_for_line_cnn = {**data_config}
        data_config_for_line_cnn["mapping"] = list(range(self.dim))
        self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args)
        # LineCNN outputs (B, E, Sx) features, with E == dim

        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.pos_encoder = PositionalEncoding(d_model=self.dim)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def init_weights(self):
        """Uniformly initialize the token embedding and output projection in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, 1, H, W) image, as expected by LineCNN

        Returns
        -------
        torch.Tensor
            (Sx, B, E) sequence of embeddings
        """
        x = self.line_cnn(x)  # (B, E, Sx)
        x = x * math.sqrt(self.dim)
        x = x.permute(2, 0, 1)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x using preceding ground truth y.

        Parameters
        ----------
        x
            (Sx, B, E) image encoded as a sequence
        y
            (B, Sy) with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) logits
        """
        y_padding_mask = y == self.padding_token  # (B, Sy); True where attention should ignore
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict sequences of tokens from input images auto-regressively.

        Parameters
        ----------
        x
            (B, 1, H, W) image

        Returns
        -------
        torch.Tensor
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, S)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token
            # Every sequence has ended: the clean-up loop below would pad the rest anyway.
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    @staticmethod
    def add_to_argparse(parser):
        LineCNN.add_to_argparse(parser)
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab06/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
FC1_DIM = 1024
FC2_DIM = 128
FC_DROPOUT = 0.5


class MLP(nn.Module):
    """Three-layer perceptron for single-character recognition.

    Input images are flattened, passed through two ReLU hidden layers (each followed by
    the shared dropout module), and projected to one logit per class.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_dim = np.prod(data_config["input_dims"])
        num_classes = len(data_config["mapping"])
        hidden1 = self.args.get("fc1", FC1_DIM)
        hidden2 = self.args.get("fc2", FC2_DIM)
        dropout_p = self.args.get("fc_dropout", FC_DROPOUT)

        # Creation order is kept fixed; it determines parameter init order.
        self.fc1 = nn.Linear(input_dim, hidden1)
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Map a (B, ...) batch of images to (B, num_classes) logits."""
        x = torch.flatten(x, 1)
        for hidden_layer in (self.fc1, self.fc2):
            x = self.dropout(F.relu(hidden_layer(x)))
        return self.fc3(x)

    @staticmethod
    def add_to_argparse(parser):
        for flag, kind, default in (
            ("--fc1", int, FC1_DIM),
            ("--fc2", int, FC2_DIM),
            ("--fc_dropout", float, FC_DROPOUT),
        ):
            parser.add_argument(flag, type=kind, default=default)
        return parser
================================================
FILE: lab06/text_recognizer/models/resnet_transformer.py
================================================
"""Model combining a ResNet with a Transformer for image-to-sequence tasks."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
import torchvision
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage
TF_DIM = 256
TF_FC_DIM = 1024
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4
RESNET_DIM = 512  # hard-coded


class ResnetTransformer(nn.Module):
    """Pass an image through a Resnet and decode the resulting embedding with a Transformer."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        self.mapping = data_config["mapping"]
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # decode() and forward() need start, end and padding token indices, so the mapping must
        # contain "<S>", "<E>" and "<P>". A KeyError here means it does not. Previously only a
        # start token (" ") was set, and any decode/forward call raised AttributeError.
        self.start_token = inverse_mapping["<S>"]
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # ## Encoder part - should output vector sequence of length self.dim per sample
        resnet = torchvision.models.resnet18(weights=None)
        self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2]))  # Exclude AvgPool and Linear layers
        # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32

        self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
        # encoder_projection will output (B, dim, _H, _W) logits

        self.enc_pos_encoder = PositionalEncodingImage(
            d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
        )  # Max (Ho, Wo)

        # ## Decoder part
        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce sequences of labels from input images.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        output_tokens
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, Sy)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    def init_weights(self):
        """Initialize the decoder embedding/projection and the encoder projection.

        The embedding and output projection weights are uniform in [-0.1, 0.1], and the output
        bias is zeroed. encoder_projection gets Kaiming-normal weights (fan_out, ReLU gain) and
        a bias uniform in [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        """
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

        nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
        if self.encoder_projection.bias is not None:
            _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # uniform_(tensor, a, b). The previous normal_(tensor, -bound, bound) read these
            # as (mean, std), which gave a bias centered at -bound.
            nn.init.uniform_(self.encoder_projection.bias, -bound, bound)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps
        """
        _B, C, _H, _W = x.shape
        if C == 1:
            x = x.repeat(1, 3, 1, 1)  # the ResNet stem expects 3 input channels
        x = self.resnet(x)  # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs
        x = self.encoder_projection(x)  # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs

        # Unlike the decoder embeddings, x is deliberately NOT scaled by sqrt(self.dim):
        # doing so prevented any learning.
        x = self.enc_pos_encoder(x)  # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32
        x = torch.flatten(x, start_dim=2)  # (B, E, Ho * Wo)
        x = x.permute(2, 0, 1)  # (Sx, B, E); Sx = Ho * Wo
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x with guiding sequences y.

        During autoregressive inference, the guiding sequence will be previous predictions.

        During training, the guiding sequence will be the ground truth.

        Parameters
        ----------
        x
            (Sx, B, E) images encoded as sequences of embeddings
        y
            (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) batch of logit sequences
        """
        y_padding_mask = y == self.padding_token  # (B, Sy); True where attention should ignore
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.dec_pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        # NOTE(review): the CLI default here is TF_DIM (256), while the constructor falls back to
        # TF_FC_DIM (1024) when no args are given. Left unchanged so command-line runs keep
        # their existing architecture (and checkpoint compatibility); confirm which is intended.
        parser.add_argument("--tf_fc_dim", type=int, default=TF_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab06/text_recognizer/models/transformer_util.py
================================================
"""Position Encoding and other utilities for Transformers."""
import math
import torch
from torch import Tensor
import torch.nn as nn
class PositionalEncodingImage(nn.Module):
    """
    Module used to add 2-D positional encodings to the feature-map produced by the encoder.

    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.

    The first d_model // 2 channels encode the row index and the remaining channels encode the
    column index, each with the 1-D sinusoidal scheme of PositionalEncoding.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w)  # (d_model, max_h, max_w)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
        """Build a (d_model, max_h, max_w) table: row encodings stacked over column encodings."""
        pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h)  # (max_h, 1, d_model // 2)
        pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w)  # (d_model // 2, max_h, max_w)

        pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w)  # (max_w, 1, d_model // 2)
        pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1)  # (d_model // 2, max_h, max_w)

        pe = torch.cat([pe_h, pe_w], dim=0)  # (d_model, max_h, max_w)
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Add the top-left (H, W) window of the table to a (B, d_model, H, W) feature map."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        x = x + self.pe[:, : x.size(2), : x.size(3)]  # type: ignore
        return x


class PositionalEncoding(torch.nn.Module):
    """Classic Attention-is-all-you-need positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        pe = self.make_pe(d_model=d_model, max_len=max_len)  # (max_len, 1, d_model)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> torch.Tensor:
        """Build a (max_len, 1, d_model) sinusoidal table: sin on even channels, cos on odd ones.

        Works for any positive d_model, including odd values. PositionalEncodingImage requests
        d_model // 2 channels, which may be odd.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd channels, one fewer than even ones when d_model is odd.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(1)
        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first S rows of the table to an (S, B, d_model) sequence, then apply dropout."""
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Return a (size, size) float32 causal attention mask.

    Entries on and below the diagonal are 0.0 (attend). Entries above it are -inf (block),
    so position i can only attend to positions <= i.
    """
    blocked = torch.full((size, size), float("-inf"), dtype=torch.float32)
    return torch.triu(blocked, diagonal=1)
================================================
FILE: lab06/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """A stem for models operating on images.

    Images are presumed to arrive as PIL images, as is standard for torchvision Datasets.

    Processing runs in three stages:
    pil_transforms (PIL image in, PIL image out), then conversion to a tensor,
    then torch_transforms (tensor in, tensor out, evaluated under torch.no_grad()).

    Both transform stages default to identities. torch_transforms is a torch.nn.Sequential,
    so it can be torchscripted whenever its underlying Modules can.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()

    def __call__(self, img):
        tensor = self.pil_to_tensor(self.pil_transforms(img))
        with torch.no_grad():
            return self.torch_transforms(tensor)
class MNISTStem(ImageStem):
    """Image stem that normalizes MNIST digits with a fixed mean (0.1307) and std (0.3081)."""

    def __init__(self):
        super().__init__()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.torch_transforms = torch.nn.Sequential(normalize)
================================================
FILE: lab06/text_recognizer/stems/line.py
================================================
import random
from PIL import Image
from torchvision import transforms
import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.stems.image import ImageStem
class LineStem(ImageStem):
    """A stem for handling images containing a line of text.

    With augment=True, a ColorJitter and then a RandomAffine are applied to the PIL image.
    The keyword dictionaries override their default settings.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()
        color_jitter_kwargs = {"brightness": (0.5, 1)} if color_jitter_kwargs is None else color_jitter_kwargs
        random_affine_kwargs = (
            {
                "degrees": 3,
                "translate": (0, 0.05),
                "scale": (0.4, 1.1),
                "shear": (-40, 50),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }
            if random_affine_kwargs is None
            else random_affine_kwargs
        )
        if not augment:
            return  # keep the identity pil_transforms from ImageStem

        jitter = transforms.ColorJitter(**color_jitter_kwargs)
        affine = transforms.RandomAffine(**random_affine_kwargs)
        self.pil_transforms = transforms.Compose([jitter, affine])
class IAMLineStem(ImageStem):
    """A stem for handling images containing lines of text from the IAMLines dataset.

    Each crop is first resized and pasted onto a fixed-size canvas
    (metadata.IMAGE_WIDTH x metadata.IMAGE_HEIGHT). With augment=True, a random horizontal
    stretch, a ColorJitter and a RandomAffine are also applied.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()

        def embed_crop(crop, augment=augment):
            # crop is PIL.image of dtype="L" (so values range from 0 -> 255)
            canvas = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))

            # Scale to the canvas height, keeping the aspect ratio.
            width, height = crop.size
            target_height = metadata.IMAGE_HEIGHT
            target_width = int(target_height * (width / height))
            if augment:
                # Random horizontal stretch of up to +/-10%.
                target_width = int(target_width * random.uniform(0.9, 1.1))
            target_width = min(target_width, metadata.IMAGE_WIDTH)
            resized = crop.resize((target_width, target_height), resample=Image.BILINEAR)

            # Indent by up to one character width, as far as the remaining room allows.
            left = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - target_width)
            top = metadata.IMAGE_HEIGHT - target_height
            canvas.paste(resized, (left, top))
            return canvas

        color_jitter_kwargs = {"brightness": (0.8, 1.6)} if color_jitter_kwargs is None else color_jitter_kwargs
        random_affine_kwargs = (
            {
                "degrees": 1,
                "shear": (-30, 20),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }
            if random_affine_kwargs is None
            else random_affine_kwargs
        )

        steps = [transforms.Lambda(embed_crop)]
        if augment:
            steps.extend(
                [
                    transforms.ColorJitter(**color_jitter_kwargs),
                    transforms.RandomAffine(**random_affine_kwargs),
                ]
            )
        self.pil_transforms = transforms.Compose(steps)
================================================
FILE: lab06/text_recognizer/stems/paragraph.py
================================================
"""IAMParagraphs Stem class."""
import torchvision.transforms as transforms
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.stems.image import ImageStem
IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE

MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH


class ParagraphStem(ImageStem):
    """A stem for handling images that contain a paragraph of text.

    Without augmentation, images are center-cropped to IMAGE_SHAPE. With augmentation, a chain
    of color, crop, geometric, blur and sharpness transforms is applied. Each keyword dictionary
    overrides the defaults of the matching transform.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()

        if not augment:
            self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)])
            return

        bilinear = transforms.InterpolationMode.BILINEAR
        color_jitter_kwargs = (
            {"brightness": 0.4, "contrast": 0.4} if color_jitter_kwargs is None else color_jitter_kwargs
        )
        random_affine_kwargs = (
            {"degrees": 3, "shear": 6, "scale": (0.95, 1), "interpolation": bilinear}
            if random_affine_kwargs is None
            else random_affine_kwargs
        )
        random_perspective_kwargs = (
            {"distortion_scale": 0.2, "p": 0.5, "interpolation": bilinear}
            if random_perspective_kwargs is None
            else random_perspective_kwargs
        )
        gaussian_blur_kwargs = (
            {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if gaussian_blur_kwargs is None else gaussian_blur_kwargs
        )
        sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} if sharpness_kwargs is None else sharpness_kwargs

        # The random crop pads with zeros when the image is smaller than IMAGE_SHAPE (576, 640).
        random_crop = transforms.RandomCrop(
            size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant"
        )
        self.pil_transforms = transforms.Compose(
            [
                transforms.ColorJitter(**color_jitter_kwargs),
                random_crop,
                transforms.RandomAffine(**random_affine_kwargs),
                transforms.RandomPerspective(**random_perspective_kwargs),
                transforms.GaussianBlur(**gaussian_blur_kwargs),
                transforms.RandomAdjustSharpness(**sharpness_kwargs),
            ]
        )
================================================
FILE: lab06/text_recognizer/tests/test_callback_utils.py
================================================
"""Tests for the text_recognizer.callbacks.util module."""
import random
import string
import tempfile
import pytorch_lightning as pl
from text_recognizer.callbacks.util import check_and_warn
def test_check_and_warn_simple():
    """check_and_warn is truthy for a missing attribute and falsy for one every object has."""

    class Foo:
        pass  # deliberately bare: only the attributes every Python object has

    missing_attribute = "".join(random.choices(string.ascii_lowercase, k=10))

    assert check_and_warn(Foo(), missing_attribute, "random feature")
    assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects")
def test_check_and_warn_tblogger():
    """Test that we return a truthy value when trying to log tables with TensorBoard.

    We added check_and_warn in order to prevent a crash if this happens.
    """
    # Use the TemporaryDirectory as a context manager. This passes the logger a real
    # directory path (not the TemporaryDirectory object itself) and removes the
    # directory once the test finishes.
    with tempfile.TemporaryDirectory() as save_dir:
        tblogger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
        assert check_and_warn(tblogger, "log_table", "tables")
def test_check_and_warn_wandblogger():
    """check_and_warn must return a falsy value for W&B, which does support table logging.

    The guard exists to avoid crashes on unsupported loggers; it must not block
    the feature on the happy path.
    """
    logger = pl.loggers.WandbLogger(anonymous=True)
    warned = check_and_warn(logger, "log_table", "tables")
    assert not warned
================================================
FILE: lab06/text_recognizer/tests/test_iam.py
================================================
"""Test for data.iam module."""
from text_recognizer.data.iam import IAM
def test_iam_parsed_lines():
    """Every IAM document has as many line labels as line image crop regions."""
    iam = IAM()
    iam.prepare_data()
    for iam_id in iam.all_ids:
        n_strings = len(iam.line_strings_by_id[iam_id])
        n_regions = len(iam.line_regions_by_id[iam_id])
        assert n_strings == n_regions, iam_id
def test_iam_data_splits():
    """The train, validation, and test splits share no document identifiers."""
    iam = IAM()
    iam.prepare_data()

    train_set = set(iam.train_ids)
    validation_set = set(iam.validation_ids)
    test_set = set(iam.test_ids)

    assert train_set.isdisjoint(validation_set)
    assert train_set.isdisjoint(test_set)
    assert validation_set.isdisjoint(test_set)
================================================
FILE: lab06/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """Return a ``uint8`` one-hot encoding of the integer label(s) ``y``.

    Rows of a ``num_classes`` x ``num_classes`` identity matrix are selected
    by ``y``, so the output shape is ``(*shape(y), num_classes)``.
    """
    identity = np.eye(num_classes, dtype=np.uint8)
    return identity[y]
def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Open an image from a local path or any URI that ``smart_open`` can read.

    Parameters
    ----------
    image_uri
        Local path or remote URI of the image file.
    grayscale
        If True, convert the image to single-channel ``"L"`` mode.

    Returns
    -------
    image
        The PIL image produced by ``read_image_pil_file``.
    """
    # The return annotation names the ``PIL.Image.Image`` class; the bare ``Image``
    # used previously referred to the ``PIL.Image`` module.
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Read a PIL image from an open binary file object.

    Parameters
    ----------
    image_file
        A readable binary file-like object containing encoded image data.
    grayscale
        If True, convert to single-channel ``"L"`` mode; otherwise keep the
        image's own mode.

    Returns
    -------
    image
        A new image produced by ``Image.convert``.
    """
    with Image.open(image_file) as image:
        # convert() always returns a new image, even when the mode is unchanged.
        # That copy stays usable after the ``with`` block closes the source image.
        target_mode = "L" if grayscale else image.mode
        return image.convert(mode=target_mode)
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Run the ``with`` body inside ``working_dir``, then restore the previous cwd.

    The previous directory is restored even when the body raises.
    """
    origin = os.getcwd()
    os.chdir(working_dir)
    with contextlib.ExitStack() as cleanup:
        # Registered only after the chdir succeeded, mirroring a try/finally around the yield.
        cleanup.callback(os.chdir, origin)
        yield
def compute_sha256(filename: Union[Path, str]):
    """Return the hex SHA-256 digest of a file's contents.

    The file is hashed in 1 MiB chunks rather than read into memory all at once,
    so large downloaded datasets can be verified with constant memory use.

    Parameters
    ----------
    filename
        Path of the file to hash.

    Returns
    -------
    str
        Lowercase hexadecimal SHA-256 digest.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """tqdm progress bar exposing a hook compatible with ``urlretrieve``'s ``reporthook``.

    Adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """Advance the bar to an absolute position.

        Parameters
        ----------
        blocks: int, optional
            Count of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block, in tqdm units [default: 1].
        tsize: int, optional
            Total size in tqdm units; left unchanged when None [default: None].
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # update() takes a delta; it also moves self.n forward to `transferred`.
        self.update(transferred - self.n)
def download_url(url, filename):
    """Save the resource at ``url`` to ``filename``, showing a byte-level progress bar."""
    bar_options = {"unit": "B", "unit_scale": True, "unit_divisor": 1024, "miniters": 1}
    with TqdmUpTo(**bar_options) as progress:
        urlretrieve(url, filename, reporthook=progress.update_to, data=None)  # noqa: S310
================================================
FILE: lab06/training/__init__.py
================================================
================================================
FILE: lab06/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import callbacks as cb
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# In order to ensure reproducible experiments, we must set random seeds.
# These run at import time and seed NumPy's global RNG and PyTorch's default
# generator; Python's `random` module is not seeded here.
np.random.seed(42)
torch.manual_seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments.

    The parser is built in stages. Trainer and basic arguments come first. A partial
    parse then reads --data_class/--model_class, those classes add their own
    arguments, and --help is added last so that it documents every argument.

    Returns
    -------
    argparse.ArgumentParser
        The fully assembled parser.
    """
    # add_help=False: --help is only registered at the end (below), after the
    # data/model-specific arguments exist, so that `--help` lists them too.
    parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    # NOTE(review): relies on argparse's private `_action_groups`; index 1 is assumed to be
    # the group to label as Trainer args. Confirm against the installed Lightning/argparse versions.
    trainer_parser._action_groups[1].title = "Trainer Args"
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    parser.set_defaults(max_epochs=1)

    # Basic arguments
    parser.add_argument(
        "--wandb",
        action="store_true",
        default=False,
        help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.",
    )
    parser.add_argument(
        "--data_class",
        type=str,
        default="MNIST",
        help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--model_class",
        type=str,
        default="MLP",
        help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path."
    )
    parser.add_argument(
        "--stop_early",
        type=int,
        default=0,
        help="If non-zero, applies early stopping, with the provided value as the 'patience' argument."
        + " Default is 0.",
    )

    # Get the data and model classes, so that we can add their specific arguments
    # parse_known_args ignores flags that are not registered yet (e.g. class-specific ones).
    temp_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}")

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)

    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)

    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_models.BaseLitModel.add_to_argparse(lit_model_group)

    # Registered last so the help text covers all argument groups added above.
    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create ``experiment_dir`` and any missing parents. This is a no-op if it already exists.

    Wrapped in ``rank_zero_only`` so that only the rank-zero process creates it.
    """
    log_path = Path(experiment_dir)
    log_path.mkdir(exist_ok=True, parents=True)
def main():
    """
    Run an experiment.

    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```

    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```

    The available command line args differ depending on some of the arguments, including --model_class and --data_class.

    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```
    """
    parser = _setup_parser()
    args = parser.parse_args()
    data, model = setup_data_and_model_from_args(args)

    # The --loss flag selects which LightningModule wraps the model.
    lit_model_class = lit_models.BaseLitModel

    if args.loss == "transformer":
        lit_model_class = lit_models.TransformerLitModel

    if args.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)
    else:
        lit_model = lit_model_class(args=args, model=model)

    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    logger = pl.loggers.TensorBoardLogger(log_dir)
    # Checkpoints below are written to the TensorBoard run directory.
    experiment_dir = logger.log_dir

    # Transformer models are checkpointed on character error rate; others on validation loss.
    goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss"
    filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
    if goldstar_metric == "validation/cer":
        filename_format += "-validation.cer={validation/cer:.3f}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename=filename_format,
        monitor=goldstar_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    summary_callback = pl.callbacks.ModelSummary(max_depth=2)

    callbacks = [summary_callback, checkpoint_callback]
    if args.wandb:
        # Replaces the TensorBoard logger. The checkpoint callback above keeps its dirpath;
        # the updated experiment_dir is only used later (for profiler output).
        logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train")
        logger.watch(model, log_freq=max(100, args.log_every_n_steps))
        logger.log_hyperparams(vars(args))
        experiment_dir = logger.experiment.dir
    callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()]
    if args.stop_early:
        early_stopping_callback = pl.callbacks.EarlyStopping(
            monitor="validation/loss", mode="min", patience=args.stop_early
        )
        callbacks.append(early_stopping_callback)

    if args.wandb and args.loss in ("transformer",):
        callbacks.append(cb.ImageToTextLogger())

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)

    if args.profile:
        sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0)
        profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir)
        profiler.STEP_FUNCTIONS = {"training_step"}  # only profile training
    else:
        profiler = pl.profiler.PassThroughProfiler()

    trainer.profiler = profiler

    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate

    trainer.fit(lit_model, datamodule=data)

    trainer.profiler = pl.profiler.PassThroughProfiler()  # turn profiling off during testing

    # Test the best checkpoint if one was saved; otherwise test the in-memory model.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        rank_zero_info(f"Best model saved at: {best_model_path}")
        if args.wandb:
            rank_zero_info("Best model also uploaded to W&B ")
        trainer.test(datamodule=data, ckpt_path=best_model_path)
    else:
        trainer.test(lit_model, datamodule=data)
if __name__ == "__main__":
    # Script entry point: parse CLI args and run a single experiment.
    main()
================================================
FILE: lab06/training/tests/test_memorize_iam.sh
================================================
#!/bin/bash
set -uo pipefail
# Don't exit on the first failing command: failures are recorded in FAILURE,
# so the script can still print its summary message before exiting.
set +e

# tests whether we can achieve a criterion loss
# on a single batch within a certain number of epochs

FAILURE=false

# constants and CLI args set by aiming for <5 min test on commodity GPU,
# including data download step
MAX_EPOCHS="${1:-100}"  # syntax for basic optional arguments in bash
CRITERION="${2:-1.0}"

# train on GPU if it's available
# (prints 1 if CUDA is available, else 0; passed to --gpus below)
GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))')

# --overfit_batches 1 trains repeatedly on a single batch; testing is disabled via --limit_test_batches 0.0.
# NOTE(review): --precision 16 is passed even when GPU=0; confirm the installed Lightning accepts it on CPU.
python ./training/run_experiment.py \
  --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \
  --augment_data false --tf_dropout 0.0 \
  --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \
  --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true

# Read the final train/loss from the most recent W&B run summary (written under training/logs
# because of --wandb above) and check it against CRITERION.
python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true

if [ "$FAILURE" = true ]; then
  echo "Memorization test failed at loss criterion $CRITERION"
  exit 1
fi
echo "Memorization test passed at loss criterion $CRITERION"
exit 0
================================================
FILE: lab06/training/tests/test_run_experiment.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# Smoke tests for training/run_experiment.py. Each run's failure is recorded rather
# than aborting the script, so both runs execute before a verdict is printed.
FAILURE=false

echo "running full loop test with CNN on fake data"
if ! python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1; then
  FAILURE=true
fi

echo "running fast_dev_run test of real model class on real data"
if ! python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \
  --fast_dev_run --num_sanity_val_steps 0 \
  --num_workers 1; then
  FAILURE=true
fi

if [ "$FAILURE" = true ]; then
  echo "Test for run_experiment.py failed"
  exit 1
fi
echo "Tests for run_experiment.py passed"
exit 0
================================================
FILE: lab06/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
# Packages in which --data_class / --model_class names are resolved
# (combined with the class name and passed to import_class).
DATA_CLASS_MODULE = "text_recognizer.data"
MODEL_CLASS_MODULE = "text_recognizer.models"
def import_class(module_and_class_name: str) -> type:
    """Import and return an attribute given its dotted path, e.g. 'text_recognizer.models.MLP'.

    Everything before the last dot is imported as a module. The final component
    is then looked up on that module.
    """
    module_name, class_name = module_and_class_name.rsplit(".", maxsplit=1)
    return getattr(importlib.import_module(module_name), class_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Build the data module and model named by ``args.data_class`` and ``args.model_class``.

    The data module is created first because the model is configured from its
    ``config()``.

    Returns
    -------
    tuple
        ``(data, model)``.
    """
    data_cls = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}")
    model_cls = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}")
    data = data_cls(args)
    return data, model_cls(data_config=data.config(), args=args)
================================================
FILE: lab07/.flake8
================================================
[flake8]
select = ANN,B,B9,BLK,C,D,E,F,I,S,W
# only check selected error codes
max-complexity = 12
# C9 - flake8 McCabe Complexity checker -- threshold
max-line-length = 120
# E501 - flake8 -- line length too long, actually handled by black
extend-ignore =
# E W - flake8 PEP style check
E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks
# S - flake8-bandit safety check
S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password
# ANN - flake8-annotations type annotation check
ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some
# D1 - flake8-docstrings docstring style check
D100,D102,D103,D104,D105, # missing docstrings
# D2 D4 - flake8-docstrings docstring style check
D200,D205,D400,D401, # whitespace issues and first line content
# DAR - flake8-darglint docstring correctness check
DAR103, # mismatched or missing type in docstring
application-import-names = app_gradio,text_recognizer,tests,training
# flake8-import-order: which names are first party?
import-order-style = google
# flake8-import-order: which import order style guide do we use?
docstring-convention = numpy
# flake8-docstrings: which docstring style guide do we use?
strictness = short
# darglint: how "strict" are we with docstring completeness?
docstring-style = numpy
# darglint: which docstring style guide do we use?
suppress-none-returning = true
# flake8-annotations: do we allow un-annotated Nones in returns?
mypy-init-return = true
# flake8-annotations: do we allow init to have no return annotation?
per-file-ignores =
# list of case-by-case ignores, see files for details
*/__init__.py:F401,I
*/data/*.py:DAR
data/*.py:F,I
*text_recognizer/util.py:DAR101,F401
*training/run_experiment.py:I202
*app_gradio/app.py:I202
================================================
FILE: lab07/.github/workflows/pre-commit.yml
================================================
name: pre-commit
on:
pull_request:
push:
# allows this Action to be triggered manually
workflow_dispatch:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.10'
- uses: pre-commit/action@v3.0.0
================================================
FILE: lab07/.pre-commit-config.yaml
================================================
repos:
# a set of useful Python-based pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
# list of definitions and supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace # removes any whitespace at the ends of lines
- id: check-toml # check toml syntax by loading all toml files
- id: check-yaml # check yaml syntax by loading all yaml files
- id: check-json # check-json syntax by loading all json files
- id: check-merge-conflict # check for files with merge conflict strings
args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge
- id: check-added-large-files # check that no "large" files have been added
args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server
- id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.)
- id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.)
# black python autoformatting
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
# additional configuration of black in pyproject.toml
# flake8 python linter with all the fixins
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08)
additional_dependencies: [
flake8-bandit, flake8-bugbear, flake8-docstrings,
flake8-import-order, darglint, mypy, pycodestyle, pydocstyle]
args: ["--config", ".flake8"]
# additional configuration of flake8 and extensions in .flake8
# shellcheck-py for linting shell files
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.8.0.4
hooks:
- id: shellcheck
================================================
FILE: lab07/api_serverless/Dockerfile
================================================
# Starting from an official AWS image
# Keep any dependencies and versions in this file aligned with the environment.yml and Makefile
FROM public.ecr.aws/lambda/python:3.10
# Install Python dependencies
COPY requirements/prod.txt ./requirements.txt
RUN pip install --upgrade pip==23.1.2
RUN pip install -r requirements.txt
# Copy only the relevant directories and files
# note that we use a .dockerignore file to avoid copying logs etc.
COPY text_recognizer/ ./text_recognizer
COPY api_serverless/api.py ./api.py
CMD ["api.handler"]
================================================
FILE: lab07/api_serverless/__init__.py
================================================
"""Cloud function-backed API for paragraph recognition."""
================================================
FILE: lab07/api_serverless/api.py
================================================
"""AWS Lambda function serving text_recognizer predictions."""
import json
from PIL import ImageStat
from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer
import text_recognizer.util as util
# Instantiated once at import time and reused by every handler invocation in this process.
model = ParagraphTextRecognizer()
def handler(event, _context):
    """Lambda entry point: load the image described by ``event`` and return predicted text.

    Returns ``{"pred": <text>}`` on success. If the event carries neither
    ``image_url`` nor ``image``, returns a dict with ``statusCode`` 400.
    Progress and METRIC lines are printed to stdout.
    """
    print("INFO loading image")
    image = _load_image(event)
    if image is None:
        return {"statusCode": 400, "message": "neither image_url nor image found in event"}
    print("INFO image loaded")

    print("INFO starting inference")
    pred = model.predict(image)
    print("INFO inference complete")

    stats = ImageStat.Stat(image)
    width, height = image.size
    print(f"METRIC image_mean_intensity {stats.mean[0]}")
    print(f"METRIC image_area {width * height}")
    print(f"METRIC pred_length {len(pred)}")
    print(f"INFO pred {pred}")
    return {"pred": str(pred)}
def _load_image(event):
    """Extract a grayscale PIL image from a Lambda event, or return None.

    The event (and its optional ``body``) may be JSON strings. ``image_url`` takes
    precedence over a base64-encoded ``image``.
    """
    payload = _from_string(event)
    payload = _from_string(payload.get("body", payload))

    image_url = payload.get("image_url")
    if image_url is not None:
        print("INFO url {}".format(image_url))
        return util.read_image_pil(image_url, grayscale=True)

    encoded_image = payload.get("image")
    if encoded_image is None:
        return None
    print("INFO reading image from event")
    return util.read_b64_image(encoded_image, grayscale=True)
def _from_string(event):
if isinstance(event, str):
return json.loads(event)
else:
return event
================================================
FILE: lab07/app_gradio/Dockerfile
================================================
# The "buster" flavor of the official docker Python image is based on Debian and includes common packages.
# Keep any dependencies and versions in this file aligned with the environment.yml and Makefile
FROM python:3.10-buster
# Create the working directory
# set -x prints commands and set -e causes us to stop on errors
RUN set -ex && mkdir /repo
WORKDIR /repo
# Install Python dependencies
COPY requirements/prod.txt ./requirements.txt
RUN pip install --upgrade pip==23.1.2
RUN pip install -r requirements.txt
ENV PYTHONPATH ".:"
# Copy only the relevant directories
# note that we use a .dockerignore file to avoid copying logs etc.
COPY text_recognizer/ ./text_recognizer
COPY app_gradio/ ./app_gradio
# Use docker run -it --rm -p $PORT:11717 to run the web server and listen on host $PORT
# add --help to see help for the Python script
ENTRYPOINT ["python3", "app_gradio/app.py", "--port", "11717"]
================================================
FILE: lab07/app_gradio/README.md
================================================
## Full-Paragraph Optical Character Recognition
For more on how this application works,
[check out the GitHub repo](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022).
### Flagging
If the model outputs in the top-right are wrong in some way,
let us know by clicking the "flagging" buttons underneath.
We'll analyze the results with
[Gantry](https://gantry.io/blog/introducing-gantry/)
and use them to improve the model!
================================================
FILE: lab07/app_gradio/__init__.py
================================================
================================================
FILE: lab07/app_gradio/app.py
================================================
"""Provide an image of handwritten text and get back out a string!"""
import argparse
import json
import logging
import os
from pathlib import Path
from typing import Callable
import gradio as gr
from PIL import ImageStat
from PIL.Image import Image
import requests
from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer
import text_recognizer.util as util
# NOTE(review): set after the imports above; this only hides GPUs if CUDA has not been initialized yet.
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # do not use GPU
logging.basicConfig(level=logging.INFO)  # PredictorBackend logs METRIC/PRED lines at INFO

APP_DIR = Path(__file__).resolve().parent  # what is the directory for this application?
FAVICON = APP_DIR / "1f95e.png"  # path to a small image for display in browser tab and social media
README = APP_DIR / "README.md"  # path to an app readme file in HTML/markdown

DEFAULT_PORT = 11700  # used when --port is not passed on the command line
def main(args):
    """Launch the Gradio UI, serving predictions locally or from ``args.model_url``."""
    backend = PredictorBackend(url=args.model_url)
    frontend = make_frontend(backend.run)
    frontend.launch(
        server_name="0.0.0.0",  # bind all interfaces so the server is reachable # noqa: S104
        server_port=args.port,  # bind to this port, failing if it is unavailable
        share=True,  # also create a temporary public link on https://gradio.app
        favicon_path=FAVICON,  # icon shown in the browser's address bar / tab
    )
def make_frontend(
    fn: Callable[[Image], str],
):
    """Creates a gradio.Interface frontend for an image to text function.

    Parameters
    ----------
    fn
        Function mapping a PIL image to its predicted text.

    Returns
    -------
    frontend
        A ``gr.Interface`` ready to be launched.
    """
    # Resolved relative to the current working directory.
    examples_dir = Path("text_recognizer") / "tests" / "support" / "paragraphs"
    # os.listdir returns entries in arbitrary, filesystem-dependent order. Sort them so
    # the examples are shown in the same order on every machine and restart.
    example_fnames = sorted(elem for elem in os.listdir(examples_dir) if elem.endswith(".png"))
    examples = [[str(examples_dir / fname)] for fname in example_fnames]

    allow_flagging = "never"
    # The README's later sections are only shown when manual flagging is enabled.
    readme = _load_readme(with_logging=allow_flagging == "manual")

    # build a basic browser interface to a Python function
    frontend = gr.Interface(
        fn=fn,  # which Python function are we interacting with?
        outputs=gr.components.Textbox(),  # what output widgets does it need? the default text widget
        # what input widgets does it need? we configure an image widget
        inputs=gr.components.Image(type="pil", label="Handwritten Text"),
        title="📝 Text Recognizer",  # what should we display at the top of the page?
        thumbnail=FAVICON,  # what should we display when the link is shared, e.g. on social media?
        description=__doc__,  # what should we display just above the interface?
        article=readme,  # what long-form content should we display below the interface?
        examples=examples,  # which potential inputs should we provide?
        cache_examples=False,  # should we cache those inputs for faster inference? slows down start
        allow_flagging=allow_flagging,  # should we show users the option to "flag" outputs?
    )

    return frontend
class PredictorBackend:
    """Interface to a backend that serves predictions.

    To communicate with a backend accessible via a URL, provide the url kwarg.
    Otherwise, runs a predictor locally.
    """

    def __init__(self, url=None):
        if url is not None:
            self.url = url
            self._predict = self._predict_from_endpoint
        else:
            model = ParagraphTextRecognizer()
            self._predict = model.predict

    def run(self, image):
        """Predict text for ``image``, log inference metrics, and return the prediction."""
        pred, metrics = self._predict_with_metrics(image)
        self._log_inference(pred, metrics)
        return pred

    def _predict_with_metrics(self, image):
        """Run the predictor and compute simple image and prediction statistics for logging."""
        pred = self._predict(image)

        stats = ImageStat.Stat(image)
        metrics = {
            "image_mean_intensity": stats.mean,
            "image_median": stats.median,
            "image_extrema": stats.extrema,
            "image_area": image.size[0] * image.size[1],
            "pred_length": len(pred),
        }
        return pred, metrics

    def _predict_from_endpoint(self, image):
        """Send an image to an endpoint that accepts JSON and return the predicted text.

        The endpoint should expect a base64 representation of the image, encoded as a string,
        under the key "image". It should return the predicted text under the key "pred".

        Parameters
        ----------
        image
            A PIL image of handwritten text to be converted into a string.

        Returns
        -------
        pred
            A string containing the predictor's guess of the text in the image.

        Raises
        ------
        requests.HTTPError
            If the endpoint responds with a 4xx or 5xx status code.
        """
        encoded_image = util.encode_b64_image(image)

        headers = {"Content-type": "application/json"}
        payload = json.dumps({"image": "data:image/png;base64," + encoded_image})

        response = requests.post(self.url, data=payload, headers=headers)
        # Fail with the HTTP status up front. Without this, an error response surfaces
        # later as an opaque JSON-decoding error or KeyError on "pred".
        response.raise_for_status()
        pred = response.json()["pred"]

        return pred

    def _log_inference(self, pred, metrics):
        """Log each metric, then the prediction, at INFO level."""
        for key, value in metrics.items():
            logging.info(f"METRIC {key} {value}")
        logging.info(f"PRED >begin\n{pred}\nPRED >end")
def _load_readme(with_logging=False):
with open(README) as f:
lines = f.readlines()
if not with_logging:
lines = lines[: lines.index("\n")]
readme = "".join(lines)
return readme
def _make_parser():
    """Build the command-line parser for the Gradio app.

    Returns
    -------
    argparse.ArgumentParser
        Parser accepting ``--model_url`` and ``--port``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_url",
        default=None,
        type=str,
        # Fixed typo in the user-facing help: data is "sent" (not "set") via a POST request.
        help=(
            "Identifies a URL to which to send image data. "
            "Data is base64-encoded, converted to a utf-8 string, "
            "and then sent via a POST request as JSON with the key 'image'. "
            "Default is None, which instead sends the data to a model running locally."
        ),
    )
    parser.add_argument(
        "--port",
        default=DEFAULT_PORT,
        type=int,
        help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
    )

    return parser
if __name__ == "__main__":
    # Script entry point: parse --model_url/--port and launch the app.
    parser = _make_parser()
    args = parser.parse_args()
    main(args)
================================================
FILE: lab07/app_gradio/tests/test_app.py
================================================
import json
import os
import requests
from app_gradio import app
from text_recognizer import util
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide GPUs from this test process
# Sample paragraph image; a relative path, so tests must run from the lab root.
TEST_IMAGE = "text_recognizer/tests/support/paragraphs/a01-077.png"
def test_local_run():
    """A quick test to make sure we can build the app and ping the API locally.

    Launches the Gradio frontend without blocking, checks that the UI root responds,
    then POSTs a base64-encoded test image to the predict endpoint. The server is
    always shut down afterwards so it does not outlive the test.
    """
    # Per-request limit in seconds: a hung server fails the test instead of blocking
    # the suite. It is generous because inference runs on CPU here.
    request_timeout = 600
    backend = app.PredictorBackend()
    frontend = app.make_frontend(fn=backend.run)

    # run the UI without blocking
    frontend.launch(share=False, prevent_thread_lock=True)
    try:
        local_url = frontend.local_url
        get_response = requests.get(local_url, timeout=request_timeout)
        assert get_response.status_code == 200, get_response.content

        image_b64 = util.encode_b64_image(util.read_image_pil(TEST_IMAGE))

        local_api = f"{local_url}api/predict"
        headers = {"Content-Type": "application/json"}
        payload = json.dumps({"data": ["data:image/png;base64," + image_b64]})
        post_response = requests.post(local_api, data=payload, headers=headers, timeout=request_timeout)
        assert post_response.status_code == 200, post_response.content
    finally:
        # With prevent_thread_lock the server keeps running after launch returns;
        # stop it explicitly so it does not leak into later tests.
        frontend.close()
================================================
FILE: lab07/notebooks/lab01_pytorch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" `.\n",
"\n",
"A model that always predicts ` ` can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab07/notebooks/lab03_transformers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" \", \"\")\n",
"\n",
"idx = random.randint(0, len(xs))\n",
"\n",
"print(show(ys[idx]))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4dT3UCNzTsoc"
},
"source": [
"The `ResnetTransformer` model can run on this data\n",
"if passed the `.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXL-vIGRr86D"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MMxa-oWyT01E"
},
"source": [
"Our models are now big enough\n",
"that we want to make use of GPU acceleration\n",
"as much as we can,\n",
"even when working on single inputs,\n",
"so let's cast to the GPU if we have one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-YyUM8LgvW0w"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"rnt.to(device); xs = xs.to(device); ys = ys.to(device);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-E3UdD4zUJi"
},
"source": [
"First, let's just pass it through the ResNet encoder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-LUUtlvaxrvg"
},
"outputs": [],
"source": [
"resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n",
" # resnet is designed for RGB images, so we replicate the input across channels 3 times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eimgJ5dnywjg"
},
"outputs": [],
"source": [
"resnet_idx = random.randint(0, len(resnet_embedding)) # re-execute to view a different channel\n",
"plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n",
"plt.axis(\"off\"); plt.colorbar(fraction=0.05);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These embeddings, though generated by random, untrained weights,\n",
"are not entirely useless.\n",
"\n",
"Before neural networks could be effectively\n",
"trained end to end,\n",
"they were often used with frozen random weights\n",
"everywhere except the final layer\n",
"(see e.g.\n",
"[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n",
"[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n",
"these methods were still competitive, and\n",
"[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n",
"provide a\n",
"[theoretical basis](https://arxiv.org/abs/2011.14522)\n",
"for understanding their performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ye6pW0ETzw2A"
},
"source": [
"The final result, though, is repetitive gibberish --\n",
"at the bare minimum, we need to train the unembedding/readout layer\n",
"in order to get reasonable text."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our architecture includes randomization with dropout,\n",
"so repeated runs of the cell below will generate different outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu3Pa7gLsFMo"
},
"outputs": [],
"source": [
"preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gvCXUbskv6XM"
},
"outputs": [],
"source": [
"print(show(preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without teacher forcing, runtime is also variable from iteration to iteration --\n",
"the model stops when it generates an \"end sequence\" or padding token,\n",
"which is not deterministic thanks to the dropout layers.\n",
"For similar reasons, runtime is variable across inputs.\n",
"\n",
"The variable runtime of autoregressive generation\n",
"is also not great for scaling.\n",
"In a distributed setting, as required for large scale,\n",
"forward passes need to be synced across devices,\n",
"and if one device is generating a batch of much longer sequences,\n",
"it will cause all the others to idle while they wait on it to finish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t76MSVRXV0V7"
},
"source": [
"Let's turn our model into a `TransformerLitModel`\n",
"so we can run with teacher forcing.\n",
"\n",
"> You may be wondering:\n",
" why isn't teacher forcing part of the PyTorch module?\n",
" In general, the `LightningModule`\n",
" should encapsulate things that are needed in training, validation, and testing\n",
" but not during inference.\n",
" The teacher forcing trick fits this paradigm,\n",
" even though it's so critical to what makes Transformers powerful. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qrHRKHowdDi"
},
"outputs": [],
"source": [
"import text_recognizer.lit_models\n",
"\n",
"lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlNaFqR50Oid"
},
"source": [
"Now we can use `.teacher_forward` if we also provide the target `ys`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lpZdqXS5wn0F"
},
"outputs": [],
"source": [
"forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Zx9SmsN0QLT"
},
"source": [
"This may not run faster than the `rnt.forward`,\n",
"since generations are always the maximum possible length,\n",
"but runtimes and output lengths are deterministic and constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tu-XNYpi0Qvi"
},
"source": [
"Forcing doesn't necessarily make our predictions better.\n",
"They remain highly repetitive gibberish."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JcEgify9w0sv"
},
"outputs": [],
"source": [
"forcing_preds = torch.argmax(forcing_outs, dim=0)\n",
"\n",
"print(show(forcing_preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xn6GGNzc9a3o"
},
"source": [
"## Training the `ResnetTransformer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uvZYsuSyWUXe"
},
"source": [
"We're finally ready to train this model on full paragraphs of handwritten text!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cJwC7b720Sd"
},
"source": [
"This is a more serious model --\n",
"it's the one we use in the\n",
"[deployed TextRecognizer application](http://fsdl.me/app).\n",
"It's much larger than the models we've seen thus far,\n",
"so it can easily outstrip available compute resources,\n",
"in particular GPU memory.\n",
"\n",
"To help, we use\n",
"[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n",
"which shrinks the size of most of our floats by half,\n",
"which reduces memory consumption and can speed up computation.\n",
"\n",
"If your GPU has less than 8GB of available RAM,\n",
"you'll see a \"CUDA out of memory\" `RuntimeError`,\n",
"which is something of a\n",
"[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n",
"In this case, you can resolve it by reducing the `--batch_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w1mXlhfy04Nm"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gpus = int(torch.cuda.is_available())\n",
"\n",
"if gpus:\n",
" !nvidia-smi\n",
"else:\n",
" print(\"watch out! working with this model on a typical CPU is not feasible\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "os1vW1rPZ1dy"
},
"source": [
"Even with an okay GPU, like a\n",
"[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n",
"a single epoch of training can take over 10 minutes to run.\n",
"We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n",
"but you can remove those flags to see what full training looks like."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnF6dWFn4JlZ"
},
"source": [
"It can take a long time (overnight)\n",
"to train this model to decent performance on a single GPU,\n",
"so we'll focus on other pieces for the exercises.\n",
"\n",
"> At the time of writing in mid-2022, the cheapest readily available option\n",
"for training this model to decent performance on this dataset with this codebase\n",
"comes out around $10, using\n",
"[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n",
"See, for example,\n",
"[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n",
"and associated experiment.\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HufjdUZN0t4l",
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"# above %%magic times the cell, useful as a poor man's profiler\n",
"\n",
"%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n",
" --gpus={gpus} --batch_size 16 --precision 16 \\\n",
" --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L6fQ93ju3Iku"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udb1Ekjx3L63"
},
"source": [
"### 🌟 Try out gradient accumulation and other \"training tricks\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kpqViB4p3Wfb"
},
"source": [
"Larger batches are helpful not only for increasing parallelization\n",
"and amortizing fixed costs\n",
"but also for getting more reliable gradients.\n",
"Larger batches give gradients with less noise\n",
"and, up to a point, less gradient noise means faster convergence.\n",
"\n",
"But larger batches result in larger tensors,\n",
"which take up more GPU memory,\n",
"a resource that is tightly constrained\n",
"and device-dependent.\n",
"\n",
"Does that mean we are limited in the quality of our gradients\n",
"due to our machine size?\n",
"\n",
"Not entirely:\n",
"look up the `--accumulate_grad_batches`\n",
"argument to the `pl.Trainer`.\n",
"You should be able to understand why\n",
"it makes it possible to compute the same gradients\n",
"you would find for a batch of size `k * N`\n",
"on a machine that can only run batches up to size `N`.\n",
"\n",
"Accumulating gradients across batches is among the\n",
"[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n",
"Try some of them out!\n",
"Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2vtkmX830y3"
},
"source": [
"### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n",
"\n",
"While training this model to actually fit the whole dataset is infeasible\n",
"as a short exercise on commodity hardware,\n",
"it's practical to train this model to memorize a batch of 16 examples.\n",
"\n",
"Passing the `--overfit_batches 1` flag limits the number of training batches to 1\n",
"and turns off\n",
"[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n",
"so that in each epoch, the model just sees the same single batch of data over and over again.\n",
"\n",
"At first, try training the model to a loss of `2.5` --\n",
"it should be doable in 100 epochs or less,\n",
"which is just a few minutes on a commodity GPU.\n",
"\n",
"Once you've got that working,\n",
"crank up the number of epochs by a factor of 10\n",
"and confirm that the loss continues to go down.\n",
"\n",
"Some tips:\n",
"\n",
"- Use `--limit_test_batches 0` to turn off testing.\n",
"We don't need it because we don't care about generalization\n",
"and it's relatively slow because it runs the model autoregressively.\n",
"\n",
"- Use `--help` and look through the model class args\n",
"to find the arguments used to reduce model size.\n",
"\n",
"- By default, there's lots of regularization to prevent overfitting.\n",
"Look through the args for the model class and data class\n",
"for regularization knobs to turn off or down."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab03_transformers.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab07/notebooks/lab04_experiments.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" ", *characters, *iam_characters]
if __name__ == "__main__":
    # Script entry point: hand the EMNIST data module class to load_and_print_info.
    load_and_print_info(EMNIST)
================================================
FILE: lab07/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab07/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME
# Label length per line; generated sentences use at most DEFAULT_MAX_LENGTH - 2 characters.
DEFAULT_MAX_LENGTH = 32
# Overlap between neighbouring characters, as a fraction of one character's width.
DEFAULT_MIN_OVERLAP = 0
DEFAULT_MAX_OVERLAP = 0.33
# Number of synthetic lines generated for each split.
NUM_TRAIN = 10000
NUM_VAL = 2000
NUM_TEST = 2000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.

    Lines are generated once by ``prepare_data`` and cached in an HDF5 file whose
    name encodes every generation parameter; ``setup`` loads the cached splits.
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)
        # NOTE(review): self.args is assumed to be a dict-like view of `args` prepared by
        # BaseDataModule; missing keys fall back to this module's defaults.
        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)

        self.mapping = metadata.MAPPING
        # One label per character position; labels are padded out to max_length.
        self.output_dims = (self.max_length, 1)
        # Wide enough to hold max_length characters side by side with no overlap.
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)

        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        """Register the base data-module arguments plus the line-generation options."""
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """Cache path; every generation parameter is encoded so a change forces regeneration."""
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate and cache all three splits unless a cache for these settings exists."""
        if self.data_filename.exists():
            return

        # Fixed seed for reproducible data. All splits draw from the same global NumPy
        # RNG stream, so the generation order below also affects their contents.
        np.random.seed(42)
        self._generate_data("train")
        self._generate_data("val")
        self._generate_data("test")

    def setup(self, stage: str = None) -> None:
        """Load cached splits from HDF5 for the requested Lightning stage.

        Parameters
        ----------
        stage
            "fit" or "validate" loads train and val, "test" loads test, None loads all.
        """
        print("EMNISTLinesDataset loading data from HDF5...")
        # "validate" is accepted so that trainer.validate() also gets a populated val split.
        if stage in ("fit", "validate") or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)

            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)

        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset."""
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        splits = (self.data_train, self.data_val, self.data_test)
        if all(split is None for split in splits):
            return basic

        # After a partial setup (e.g. stage="fit" or "test") some splits are still None;
        # report them as such instead of failing on len(None).
        sizes = ", ".join("None" if split is None else str(len(split)) for split in splits)
        data = f"Train/val/test sizes: {sizes}\n"
        if self.data_train is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
            )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Synthesize the lines for one split and append them to the HDF5 cache."""
        print(f"EMNISTLinesDataset generating data for {split}...")

        from text_recognizer.data.sentence_generator import SentenceGenerator

        # Subtract two because we will add start/end tokens.
        # The two positions are reserved even when with_start_end_tokens is False.
        sentence_generator = SentenceGenerator(self.max_length - 2)

        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()

        # NOTE: train and val lines are both built from EMNIST's train/val pool;
        # only the test lines use EMNIST's test characters.
        if split == "train":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train
        elif split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test

        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        # Append mode: each split adds its own x_/y_ datasets to the same file.
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Group samples under the character their label index maps to.

    Parameters
    ----------
    samples
        Iterable of samples (e.g. character images), paired positionally with ``labels``.
    labels
        Iterable of indices into ``mapping``.
    mapping
        Indexable from label index to character.

    Returns
    -------
    collections.defaultdict
        ``defaultdict(list)`` from character to the samples carrying that label,
        in their original order. Pairing stops at the shorter of samples/labels.
    """
    grouped = defaultdict(list)
    for image, label_index in zip(samples, labels):
        char = mapping[label_index]
        grouped[char].append(image)
    return grouped
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Pick one glyph image per distinct character of ``string``.

    A character that appears several times reuses the same randomly chosen sample.
    Characters with no samples get an all-zero uint8 image of ``char_shape``.

    Returns
    -------
    list
        One image per character of ``string`` (in order), each reshaped to ``char_shape``.
    """
    blank = torch.zeros(char_shape, dtype=torch.uint8)
    chosen = {}
    for char in string:
        if char in chosen:
            continue
        # Indexing (not .get) is kept deliberately: it matches the original lookup
        # on a defaultdict, including its side effect of inserting missing keys.
        candidates = samples_by_char[char]
        if candidates:
            picked = candidates[np.random.choice(len(candidates))]
        else:
            picked = blank
        chosen[char] = picked.reshape(*char_shape)
    return [chosen[char] for char in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render ``string`` as one line image by pasting a sampled glyph per character.

    Glyphs are placed left to right, each starting ``W - int(overlap * W)`` pixels after
    the previous one. ``overlap`` is drawn once per line, uniformly from
    [min_overlap, max_overlap], and ``W`` is the glyph width. Overlapping pixels are
    summed and then saturated at 255.

    Parameters
    ----------
    string
        Text to render; must be non-empty.
    samples_by_char
        Mapping from character to candidate glyph images.
    min_overlap, max_overlap
        Bounds on the overlap between neighbouring glyphs, as a fraction of glyph width.
    width
        Width in pixels of the output canvas.

    Returns
    -------
    torch.Tensor
        (H, width) float32 tensor with values clamped to at most 255.
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in a wide integer dtype: summing overlapping glyphs in uint8 would wrap
    # around at 256 (turning bright overlaps dark) before the clamp below could act.
    concatenated_image = torch.zeros((H, width), dtype=torch.int32)
    x = 0
    for image in sampled_images:
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image, dtype=torch.int32)
        x += next_overlap_width
    # Saturate at 255 and return float32, the dtype the previous
    # torch.minimum(torch.Tensor([255]), ...) expression produced.
    return torch.clamp(concatenated_image, max=255).float()
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate ``N`` synthetic line images together with their source sentences.

    Parameters
    ----------
    N
        Number of lines to generate.
    samples_by_char
        Mapping from character to candidate glyph images.
    sentence_generator
        Object whose ``generate()`` returns the next sentence to render.
    min_overlap, max_overlap
        Bounds on the overlap between neighbouring glyphs.
    dims
        Image dims; ``dims[1]`` and ``dims[2]`` size the output tensor and ``dims[-1]``
        is passed as the render width.

    Returns
    -------
    tuple
        ``(images, labels)``: a float tensor of shape ``(N, dims[1], dims[2])`` and
        the list of ``N`` sentence strings, in matching order.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    labels = []
    for index in range(N):
        sentence = sentence_generator.generate()
        rendered = construct_image_from_string(sentence, samples_by_char, min_overlap, max_overlap, dims[-1])
        images[index] = rendered
        labels.append(sentence)
    return images, labels
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" "]
self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
self.val_cer = CharacterErrorRate(self.ignore_tokens)
self.test_cer = CharacterErrorRate(self.ignore_tokens)
================================================
FILE: lab07/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence
import torch
import torchmetrics
class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored.

    Parameters
    ----------
    ignore_tokens
        Token indices removed from both predictions and targets before scoring
        (e.g. the start, end, and padding tokens).
    *args, **kwargs
        Forwarded unchanged to ``torchmetrics.CharErrorRate``, whose metric options
        are keyword arguments.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A set gives O(1) membership checks in the per-token filtering below.
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        """Drop ignored tokens from every row, then accumulate edit-distance statistics.

        NOTE(review): preds and targets are assumed to be (batch, sequence) integer
        tensors, so ``.tolist()`` yields one list of token indices per row.
        """
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)
def test_character_error_rate():
    """Check that ignored tokens (0 and 1) are stripped before the CER is computed."""
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    # The metric is corpus-level: total edits / total target characters = (0 + 3 + 2) / 12.
    # Every stripped target has 4 characters, so that equals the mean of the per-row rates.
    expected = sum([0, 0.75, 0.5]) / 3
    # Compare with a tolerance because compute() returns a float32 tensor.
    assert abs(metric.compute().item() - expected) < 1e-6
if __name__ == "__main__":
    # Running the module directly executes the metric's self-check.
    test_character_error_rate()
================================================
FILE: lab07/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence
import torch
from .base import BaseImageToTextLitModel
from .util import replace_after
class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.

    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """

    def __init__(self, model, args=None):
        super().__init__(model, args)
        # Padding positions carry no signal, so they are excluded from the loss.
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)

    def forward(self, x):
        """Production inference: delegate to the wrapped model's forward pass."""
        return self.model(x)

    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.

        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.

        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def training_step(self, batch, batch_idx):
        """Teacher-forced loss; string predictions are added only on logged batches."""
        x, y = batch
        # Feed the target without its last token and score against the target
        # without its first token (predict each next token).
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("train/loss", loss)

        outputs = {"loss": loss}

        if self.is_logged_batch():
            preds = self.get_preds(logits)
            pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
            outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs})

        return outputs

    def validation_step(self, batch, batch_idx):
        """Teacher-forced loss plus autoregressive CER, logged under "validation/"."""
        return self._evaluation_step(batch, batch_idx, "validation", self.val_cer)

    def test_step(self, batch, batch_idx):
        """Teacher-forced loss plus autoregressive CER, logged under "test/"."""
        # Uses the dedicated test metric so test results never share state with validation.
        return self._evaluation_step(batch, batch_idx, "test", self.test_cer)

    def _evaluation_step(self, batch, batch_idx, stage, cer_metric):
        """Shared validation/test logic; ``stage`` prefixes the logged keys."""
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log(f"{stage}/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)
        cer_metric(preds, y)
        self.log(f"{stage}/cer", cer_metric, prog_bar=True, sync_dist=True)

        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping."""
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        else:
            return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping."""
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.

        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.

        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        else:
            return raw  # (B, Sy)
================================================
FILE: lab07/text_recognizer/lit_models/util.py
================================================
from typing import Union
import torch
def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Find where ``element`` first occurs in ``x`` along ``dim``.

    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9

    Parameters
    ----------
    x
        One- or two-dimensional Tensor to search.
    element
        Value to look for.
    dim
        Dimension to collapse while searching.

    Returns
    -------
    torch.Tensor
        Index of the first occurrence for each slice, or ``x.shape[dim]`` where the
        element never occurs. Has one dimension fewer than ``x``.

    Raises
    ------
    ValueError
        If ``x`` is not 1- or 2-dimensional.

    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    n_dims = x.dim()
    if not 1 <= n_dims <= 2:
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {n_dims}")

    is_hit = x == element
    # The running count of hits equals 1 exactly at the first hit (and at the non-hits
    # right after it), so AND-ing with is_hit keeps only the first hit.
    first_hit = is_hit & (is_hit.cumsum(dim) == 1)
    found, position = first_hit.max(dim)
    # Slices with no hit report the length along dim (one past the last valid index).
    return torch.where(found, position, x.shape[dim])
def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Overwrite every entry that follows the first ``element`` in each row of ``x``.

    Parameters
    ----------
    x
        Two-dimensional Tensor of shape (B, S).
    element
        Value whose first occurrence in a row marks the cut-off.
    replace
        Value written into every position after that first occurrence.

    Returns
    -------
    outs
        New (B, S) Tensor. The first occurrence itself and everything before it are copied
        from ``x``; rows that do not contain ``element`` are copied unchanged.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    cutoff = first_appearance(x, element, dim=1)  # (B,); equals S for rows without element
    positions = torch.arange(0, x.shape[-1]).type_as(x)  # (S,), on x's device
    keep = positions.unsqueeze(0) <= cutoff.unsqueeze(1)  # (B, S)
    return torch.where(keep, x, replace)
================================================
FILE: lab07/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path
import text_recognizer.metadata.shared as shared
# EMNIST data locations under the shared data directory.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("raw", "emnist")
METADATA_FILENAME = RAW_DATA_DIRNAME.joinpath("metadata.toml")
DL_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("downloaded", "emnist")
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "emnist")
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME.joinpath("byclass.h5")
# Essentials file bundled with the package: text_recognizer/data/emnist_essentials.json
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve().joinpath("data", "emnist_essentials.json")
# The first NUM_SPECIAL_TOKENS entries of MAPPING are special (non-character) tokens.
# The sequence models look up "<S>" (start), "<E>" (end) and "<P>" (padding) by name;
# "<B>" is presumably the blank token for CTC-style models — confirm against callers.
NUM_SPECIAL_TOKENS = 4
INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)
# Class index -> character. Every entry must be unique so that inverse mappings
# ({char: index}) are one-to-one.
MAPPING = [
    "<B>",
    "<S>",
    "<E>",
    "<P>",
    "0",
    "1",
    "2",
    "3",
    "4",
    "5",
    "6",
    "7",
    "8",
    "9",
    "A",
    "B",
    "C",
    "D",
    "E",
    "F",
    "G",
    "H",
    "I",
    "J",
    "K",
    "L",
    "M",
    "N",
    "O",
    "P",
    "Q",
    "R",
    "S",
    "T",
    "U",
    "V",
    "W",
    "X",
    "Y",
    "Z",
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "q",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z",
    " ",
    "!",
    '"',
    "#",
    "&",
    "'",
    "(",
    ")",
    "*",
    "+",
    ",",
    "-",
    ".",
    "/",
    ":",
    ";",
    "?",
]
================================================
FILE: lab07/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
# Locations of processed EMNISTLines data and the bundled essentials file.
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "emnist_lines")
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve().joinpath("data", "emnist_lines_essentials.json")

# Lines are built from EMNIST glyphs, so they share EMNIST's channel count and glyph size.
CHAR_HEIGHT = emnist.DIMS[1]
CHAR_WIDTH = emnist.DIMS[2]
# Width is left open (None): it depends on the maximum sequence length.
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)

MAPPING = emnist.MAPPING
================================================
FILE: lab07/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared
# IAM data locations under the shared data directory.
RAW_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("raw", "iam")
METADATA_FILENAME = RAW_DATA_DIRNAME.joinpath("metadata.toml")
DL_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("downloaded", "iam")
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME.joinpath("iamdb")

# When images are downsampled, the region coordinates must be scaled by the same factor.
DOWNSAMPLE_FACTOR = 2
# Margin, in pixels, added around the exact coordinates of each line region.
LINE_REGION_PADDING = 8
================================================
FILE: lab07/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_lines")

IMAGE_SCALE_FACTOR = 2
# Rough estimate of a character's width: an EMNIST glyph side, scaled like the images.
CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR
# 3072 rounds up the empirical maximum width of IAMLines images.
IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR
DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)

# NOTE(review): 89 presumably is the maximum label length — confirm.
OUTPUT_DIMS = (89, 1)
MAPPING = emnist.MAPPING
================================================
FILE: lab07/text_recognizer/metadata/iam_paragraphs.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_paragraphs")

# Paragraph labels use the EMNIST vocabulary plus a line-break token appended at the end.
NEW_LINE_TOKEN = "\n"
MAPPING = list(emnist.MAPPING) + [NEW_LINE_TOKEN]

# must match IMAGE_SCALE_FACTOR for IAMLines to be compatible with synthetic paragraphs
IMAGE_SCALE_FACTOR = 2

IMAGE_HEIGHT = 576
IMAGE_WIDTH = 640
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH)
DIMS = (1, *IMAGE_SHAPE)

MAX_LABEL_LENGTH = 682
OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1)
================================================
FILE: lab07/text_recognizer/metadata/iam_synthetic_paragraphs.py
================================================
import text_recognizer.metadata.iam_paragraphs as iam_paragraphs
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_synthetic_paragraphs")
# Synthetic paragraphs reuse the real paragraphs' line-break token.
NEW_LINE_TOKEN = iam_paragraphs.NEW_LINE_TOKEN

# The dataset length is derived from parameters of a typical training run:
# batch size x number of GPUs x number of steps.
EXPECTED_BATCH_SIZE = 64
EXPECTED_GPUS = 8
EXPECTED_STEPS = 40
DATASET_LEN = EXPECTED_BATCH_SIZE * EXPECTED_GPUS * EXPECTED_STEPS
================================================
FILE: lab07/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME

DIMS = (1, 28, 28)  # (channels, height, width)
OUTPUT_DIMS = (1,)
MAPPING = [*range(10)]  # class index i is the digit i

TRAIN_SIZE = 55_000
VAL_SIZE = 5_000
================================================
FILE: lab07/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# Data directory four levels above this file
# (metadata/ -> text_recognizer/ -> lab folder -> its parent), resolved to an absolute path.
DATA_DIRNAME = Path(__file__).resolve().parents[3].joinpath("data")
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME.joinpath("downloaded")
================================================
FILE: lab07/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
from .resnet_transformer import ResnetTransformer
from .line_cnn_transformer import LineCNNTransformer
================================================
FILE: lab07/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
CONV_DIM = 64
FC_DIM = 128
FC_DROPOUT = 0.25


class ConvBlock(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by a ReLU.

    The padding keeps H and W unchanged; only the channel count changes.
    """

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` and rectify the result.

        Parameters
        ----------
        x
            (B, C_in, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, C_out, H, W) tensor
        """
        return self.relu(self.conv(x))


class CNN(nn.Module):
    """Small convolutional classifier for one square character image.

    Two conv blocks, a 2x2 max-pool, dropout, then two fully connected layers.
    """

    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        channels, height, width = self.data_config["input_dims"]
        assert (
            height == width
        ), f"input height and width should be equal, but was {height}, {width}"
        self.input_height = height
        self.input_width = width

        num_classes = len(self.data_config["mapping"])
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)

        # Creation order fixes the order in which parameters are randomly initialized.
        self.conv1 = ConvBlock(channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)

        # The padded 3x3 convs keep H and W; only the 2x2 max-pool halves them.
        pooled_area = (height // 2) * (width // 2)
        self.fc_input_dim = int(pooled_area * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Parameters
        ----------
        x
            (B, Ch, H, W) tensor; H and W must match the input height and width from data_config.

        Returns
        -------
        torch.Tensor
            (B, Cl) logits, Cl being the number of classes.
        """
        _, _, height, width = x.shape
        assert height == self.input_height and width == self.input_width, f"bad inputs to CNN with shape {x.shape}"
        features = self.conv2(self.conv1(x))  # (B, CONV_DIM, H, W)
        features = self.dropout(self.max_pool(features))  # (B, CONV_DIM, H // 2, W // 2)
        flat = features.flatten(1)  # (B, CONV_DIM * H // 2 * W // 2)
        hidden = F.relu(self.fc1(flat))  # (B, FC_DIM)
        return self.fc2(hidden)  # (B, Cl)

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab07/text_recognizer/models/line_cnn.py
================================================
"""Basic building blocks for convolutional models over lines of text."""
import argparse
import math
from typing import Any, Dict, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
# Common type hints
Param2D = Union[int, Tuple[int, int]]
CONV_DIM = 32
FC_DIM = 512
FC_DROPOUT = 0.2
WINDOW_WIDTH = 16
WINDOW_STRIDE = 8


class ConvBlock(nn.Module):
    """
    Conv2d followed by a ReLU.

    With the default arguments (3x3 kernel, stride 1, padding 1) the spatial size of the input
    is left unchanged; other kernel_size/stride/padding values resize it as usual for Conv2d.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: Param2D = 3,
        stride: Param2D = 1,
        padding: Param2D = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the ConvBlock to x.

        Parameters
        ----------
        x
            (B, C_in, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, C_out, H_out, W_out) tensor; H_out == H and W_out == W with the default arguments
        """
        c = self.conv(x)
        r = self.relu(c)
        return r


class LineCNN(nn.Module):
    """
    Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.args = vars(args) if args is not None else {}
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]
        C, H, _W = data_config["input_dims"]
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        # Input is (C, H, W); the channel count comes from data_config (1 for the line datasets).
        self.convs = nn.Sequential(
            ConvBlock(C, conv_dim),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim, stride=2),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim * 2, stride=2),
            ConvBlock(conv_dim * 2, conv_dim * 2),
            ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
            ConvBlock(conv_dim * 4, conv_dim * 4),
            # After three stride-2 blocks H and W are roughly 1/8 of the input, so this kernel
            # spans (about) the remaining height and one window of width WW; striding by WS // 8
            # along the width turns successive window positions into sequence steps.
            ConvBlock(
                conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0
            ),
        )
        self.fc1 = nn.Linear(fc_dim, fc_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(fc_dim, self.num_classes)

        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights in a better way than default.
        See https://github.com/pytorch/pytorch/issues/18182

        Conv/linear weights get Kaiming-normal init (fan_out mode, ReLU gain); their biases are
        drawn uniformly from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        """
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(tensor, a, b): passing (-bound, bound) to normal_(tensor, mean, std)
                    # would instead sample N(-bound, bound**2), a negatively shifted bias.
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the LineCNN to an input image.

        Parameters
        ----------
        x
            (B, C, H, W) input image, C matching data_config["input_dims"][0]

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W and self.window_width
            C is self.num_classes
        """
        _B, _C, _H, _W = x.shape
        x = self.convs(x)  # (B, FC_DIM, 1, Sx)
        x = x.squeeze(2).permute(0, 2, 1)  # (B, S, FC_DIM)
        x = F.relu(self.fc1(x))  # -> (B, S, FC_DIM)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, S, C)
        x = x.permute(0, 2, 1)  # -> (B, C, S)
        if self.limit_output_length:
            x = x[:, :, : self.output_length]
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab07/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
IMAGE_SIZE = 28
WINDOW_WIDTH = IMAGE_SIZE
WINDOW_STRIDE = IMAGE_SIZE


class LineCNNSimple(nn.Module):
    """Slide a single-character CNN across a line image whose height is IMAGE_SIZE.

    Each window of width ``WW``, taken every ``WS`` pixels, is classified independently.
    With the default window width and stride (both IMAGE_SIZE) the windows are adjacent.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        # The per-window CNN sees square WW x WW crops with the line's channel count.
        in_channels = data_config["input_dims"][0]
        window_data_config = dict(data_config, input_dims=(in_channels, self.WW, self.WW))
        self.cnn = CNN(data_config=window_data_config, args=args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify every window of the line and return the per-window logits.

        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to IMAGE_SIZE

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, C being self.num_classes and S the number of windows,
            floor((W - WW) / WS + 1); truncated to output_length when limit_output_length is set.
        """
        batch_size, _channels, height, width = x.shape
        assert height == IMAGE_SIZE  # Make sure we can use our CNN class

        num_windows = math.floor((width - self.WW) / self.WS + 1)
        # NOTE: type_as properly sets device
        activations = torch.zeros((batch_size, self.num_classes, num_windows)).type_as(x)
        for step in range(num_windows):
            left = step * self.WS
            window = x[:, :, :, left : left + self.WW]  # (B, C, H, WW)
            activations[:, :, step] = self.cnn(window)

        if self.limit_output_length:
            # S might not match ground truth, so keep only as many activations as are expected
            activations = activations[:, :, : self.output_length]
        return activations

    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab07/text_recognizer/models/line_cnn_transformer.py
================================================
"""Model that combines a LineCNN with a Transformer model for text prediction."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .line_cnn import LineCNN
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding
TF_DIM = 256
TF_FC_DIM = 256
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4


class LineCNNTransformer(nn.Module):
    """Process the line through a CNN and process the resulting sequence with a Transformer decoder."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # forward() and decode() need distinct start, end and padding token indices;
        # the mapping must therefore contain these special tokens.
        for special in ("<S>", "<E>", "<P>"):
            if special not in inverse_mapping:
                raise ValueError(f"mapping must contain the special token {special!r}")
        self.start_token = inverse_mapping["<S>"]
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # Instantiate LineCNN with "num_classes" set to self.dim
        data_config_for_line_cnn = {**data_config}
        data_config_for_line_cnn["mapping"] = list(range(self.dim))
        self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args)
        # LineCNN outputs (B, E, Sx) features, with E == dim

        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.pos_encoder = PositionalEncoding(d_model=self.dim)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def init_weights(self):
        """Uniformly initialize the token embedding and output projection; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, 1, H, W) image, as consumed by LineCNN

        Returns
        -------
        torch.Tensor
            (Sx, B, E) sequence of embeddings with positional encodings added
        """
        x = self.line_cnn(x)  # (B, E, Sx)
        x = x * math.sqrt(self.dim)
        x = x.permute(2, 0, 1)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x using preceding ground truth y.

        Parameters
        ----------
        x
            (Sx, B, E) image encoded as a sequence
        y
            (B, Sy) with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) logits
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict sequences of tokens from input images auto-regressively.

        Parameters
        ----------
        x
            (B, 1, H, W) image

        Returns
        -------
        torch.Tensor
            (B, Sy) with elements in [0, C-1] where C is num_classes;
            every position after an end or padding token holds the padding token
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, S)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Once every sequence has emitted an end or padding token, all later positions
            # are overwritten with padding below anyway, so stop decoding early.
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    @staticmethod
    def add_to_argparse(parser):
        LineCNN.add_to_argparse(parser)
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab07/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
FC1_DIM = 1024
FC2_DIM = 128
FC_DROPOUT = 0.5


class MLP(nn.Module):
    """Simple MLP suitable for recognizing single characters.

    Flattens the input and applies Linear-ReLU-Dropout twice, then a final Linear layer
    producing one logit per class.
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.data_config = data_config

        # np.prod returns a NumPy scalar; cast so nn.Linear stores a plain Python int
        # as in_features.
        input_dim = int(np.prod(self.data_config["input_dims"]))
        num_classes = len(self.data_config["mapping"])

        fc1_dim = self.args.get("fc1", FC1_DIM)
        fc2_dim = self.args.get("fc2", FC2_DIM)
        dropout_p = self.args.get("fc_dropout", FC_DROPOUT)

        self.fc1 = nn.Linear(input_dim, fc1_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, num_classes)

    def forward(self, x):
        """Map a batch of inputs (B, *input_dims) to (B, num_classes) logits."""
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--fc1", type=int, default=FC1_DIM)
        parser.add_argument("--fc2", type=int, default=FC2_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab07/text_recognizer/models/resnet_transformer.py
================================================
"""Model combining a ResNet with a Transformer for image-to-sequence tasks."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
import torchvision
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage
TF_DIM = 256
TF_FC_DIM = 1024
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4
RESNET_DIM = 512  # hard-coded


class ResnetTransformer(nn.Module):
    """Pass an image through a Resnet and decode the resulting embedding with a Transformer."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        self.mapping = data_config["mapping"]
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # forward() and decode() need distinct start, end and padding token indices;
        # the mapping must therefore contain these special tokens.
        for special in ("<S>", "<E>", "<P>"):
            if special not in inverse_mapping:
                raise ValueError(f"mapping must contain the special token {special!r}")
        self.start_token = inverse_mapping["<S>"]
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # ## Encoder part - should output vector sequence of length self.dim per sample
        resnet = torchvision.models.resnet18(weights=None)
        self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2]))  # Exclude AvgPool and Linear layers
        # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32

        self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
        # encoder_projection will output (B, dim, _H, _W) logits

        self.enc_pos_encoder = PositionalEncodingImage(
            d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
        )  # Max (Ho, Wo)

        # ## Decoder part
        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce sequences of labels from input images.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        output_tokens
            (B, Sy) with elements in [0, C-1] where C is num_classes;
            every position after an end or padding token holds the padding token
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, Sy)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    def init_weights(self):
        """Initialize embedding, output projection and the 1x1 encoder projection.

        The embedding and output weights are uniform in [-0.1, 0.1] with a zero output bias.
        The encoder projection gets Kaiming-normal weights (fan_out mode) and a bias drawn
        uniformly from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        """
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

        nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
        if self.encoder_projection.bias is not None:
            _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # uniform_(tensor, a, b): passing (-bound, bound) to normal_(tensor, mean, std)
            # would instead sample N(-bound, bound**2), a negatively shifted bias.
            nn.init.uniform_(self.encoder_projection.bias, -bound, bound)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps
        """
        _B, C, _H, _W = x.shape
        if C == 1:
            x = x.repeat(1, 3, 1, 1)
        x = self.resnet(x)  # (B, RESNET_DIM, _H // 32, _W // 32),   (B, 512, 18, 20) in the case of IAMParagraphs
        x = self.encoder_projection(x)  # (B, E, _H // 32, _W // 32),   (B, 256, 18, 20) in the case of IAMParagraphs

        # x = x * math.sqrt(self.dim)  # (B, E, _H // 32, _W // 32)  # This prevented any learning
        x = self.enc_pos_encoder(x)  # (B, E, Ho, Wo);     Ho = _H // 32, Wo = _W // 32
        x = torch.flatten(x, start_dim=2)  # (B, E, Ho * Wo)
        x = x.permute(2, 0, 1)  # (Sx, B, E);    Sx = Ho * Wo
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x with guiding sequences y.

        During autoregressive inference, the guiding sequence will be previous predictions.

        During training, the guiding sequence will be the ground truth.

        Parameters
        ----------
        x
            (Sx, B, E) images encoded as sequences of embeddings
        y
            (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) batch of logit sequences
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.dec_pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        # Default matches the constructor's fallback (TF_FC_DIM), not the model width TF_DIM.
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab07/text_recognizer/models/transformer_util.py
================================================
"""Position Encoding and other utilities for Transformers."""
import math
import torch
from torch import Tensor
import torch.nn as nn
class PositionalEncodingImage(nn.Module):
    """
    Add a fixed 2-D sinusoidal positional encoding to an encoder feature map.

    The first half of the channels encodes the row index and the second half the
    column index, following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        # The table is deterministic and can be rebuilt, so by default it is not saved in the state_dict.
        self.register_buffer(
            "pe", self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w), persistent=persistent
        )  # (d_model, max_h, max_w)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
        """Build the (d_model, max_h, max_w) encoding table."""
        half = d_model // 2
        rows = PositionalEncoding.make_pe(d_model=half, max_len=max_h)  # (max_h, 1, half)
        cols = PositionalEncoding.make_pe(d_model=half, max_len=max_w)  # (max_w, 1, half)
        row_part = rows.permute(2, 0, 1).expand(-1, -1, max_w)  # (half, max_h, max_w)
        col_part = cols.permute(2, 1, 0).expand(-1, max_h, -1)  # (half, max_h, max_w)
        return torch.cat((row_part, col_part), dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the top-left (H, W) crop of the table to x of shape (B, d_model, H, W)."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        return x + self.pe[:, :height, :width]  # type: ignore
class PositionalEncoding(torch.nn.Module):
    """Classic Attention-is-all-you-need positional encoding.

    Adds a fixed sinusoidal table to a (S, B, d_model) sequence, then applies dropout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = self.make_pe(d_model=d_model, max_len=max_len)  # (max_len, 1, d_model)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> torch.Tensor:
        """Build a (max_len, 1, d_model) table of sinusoidal encodings.

        Even channels 2i hold sin(pos * w_i) and odd channels 2i + 1 hold cos(pos * w_i),
        with w_i = 10000 ** (-2i / d_model). Any d_model >= 1 is supported; for odd d_model
        the last channel is a sine without a cosine partner.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (ceil(d_model/2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns, so only the
        # first d_model // 2 frequencies are used here (all of them when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to x of shape (S, B, d_model) and apply dropout."""
        # x.shape = (S, B, d_model)
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Build a (size, size) float32 causal attention mask.

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf otherwise.
    """
    allowed = torch.tril(torch.ones(size, size)) == 1
    return torch.zeros(size, size, dtype=torch.float).masked_fill(~allowed, float("-inf"))
================================================
FILE: lab07/text_recognizer/paragraph_text_recognizer.py
================================================
"""Detects a paragraph of text in an input image.
Example usage as a script:
python text_recognizer/paragraph_text_recognizer.py \
text_recognizer/tests/support/paragraphs/a01-077.png
python text_recognizer/paragraph_text_recognizer.py \
https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png
"""
import argparse
from pathlib import Path
from typing import Sequence, Union
from PIL import Image
import torch
from text_recognizer import util
from text_recognizer.stems.paragraph import ParagraphStem
# Default location of the TorchScript model that ParagraphTextRecognizer loads when no
# model_path is given: <this package>/artifacts/paragraph-text-recognizer/model.pt
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "paragraph-text-recognizer"
MODEL_FILE = "model.pt"
class ParagraphTextRecognizer:
    """Turn an image of a paragraph of text into a string using a TorchScript model.

    The scripted model is expected to expose ``mapping`` (token index -> string) and
    ``ignore_tokens`` (indices dropped when decoding) as attributes.
    """

    def __init__(self, model_path=None):
        # Fall back to the staged artifact that ships inside this package.
        path = STAGED_MODEL_DIRNAME / MODEL_FILE if model_path is None else model_path
        self.model = torch.jit.load(path)
        self.mapping = self.model.mapping
        self.ignore_tokens = self.model.ignore_tokens
        self.stem = ParagraphStem()

    @torch.no_grad()
    def predict(self, image: Union[str, Path, Image.Image]) -> str:
        """Recognize the text in ``image``.

        ``image`` may already be a PIL image; otherwise it is treated as a location
        (local path or URL) and read as a grayscale PIL image.
        """
        if isinstance(image, Image.Image):
            pil_image = image
        else:
            pil_image = util.read_image_pil(image, grayscale=True)
        batch = self.stem(pil_image).unsqueeze(axis=0)  # add a batch dimension of size 1
        token_ids = self.model(batch)[0]  # predictions for the single batch element
        return convert_y_label_to_string(y=token_ids, mapping=self.mapping, ignore_tokens=self.ignore_tokens)
def convert_y_label_to_string(y: torch.Tensor, mapping: Sequence[str], ignore_tokens: Sequence[int]) -> str:
    """Decode a 1-D sequence of token indices into text, skipping ignored tokens.

    Parameters
    ----------
    y
        1-D tensor (or sequence) of integer token indices.
    mapping
        Lookup from token index to its string.
    ignore_tokens
        Token indices to drop, e.g. start/end/padding markers. Ints or 0-d integer tensors.

    Returns
    -------
    str
        The concatenation of ``mapping[i]`` for every kept index, in order.
    """
    # Convert to plain Python ints once. Iterating a tensor yields 0-d tensors, which
    # makes every list lookup and membership test go through tensor operations.
    token_ids = y.tolist() if isinstance(y, torch.Tensor) else [int(i) for i in y]
    ignored = {int(token) for token in ignore_tokens}
    return "".join(mapping[i] for i in token_ids if i not in ignored)
def main():
    """Command-line entry point: recognize the paragraph in one image and print the text."""
    description = __doc__.split("\n")[0]  # first line of the module docstring
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "filename",
        type=str,
        help="Name for an image file. This can be a local path, a URL, a URI from AWS/GCP/Azure storage, an HDFS path, or any other resource locator supported by the smart_open library.",
    )
    cli_args = parser.parse_args()
    recognizer = ParagraphTextRecognizer()
    print(recognizer.predict(cli_args.filename))
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: lab07/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """Base preprocessing stem for image models.

    Inputs are expected to be PIL images, the usual convention for torchvision Datasets.
    Preprocessing runs in three stages:

    1. ``pil_transforms``: PIL image in, PIL image out.
    2. ``pil_to_tensor``: converts the PIL image into a tensor.
    3. ``torch_transforms``: tensor in, tensor out, run without gradient tracking.

    Both transform stages are identities by default; subclasses replace them. Because
    ``torch_transforms`` is a ``torch.nn.Sequential``, it is compatible with TorchScript
    whenever the underlying Modules are.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])  # identity
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()  # identity

    def __call__(self, img):
        pil_image = self.pil_transforms(img)
        tensor = self.pil_to_tensor(pil_image)
        with torch.no_grad():
            return self.torch_transforms(tensor)
class MNISTStem(ImageStem):
    """Stem for MNIST digits: default tensor conversion followed by dataset normalization."""

    def __init__(self):
        super().__init__()
        mnist_mean, mnist_std = (0.1307,), (0.3081,)
        normalize = transforms.Normalize(mnist_mean, mnist_std)
        self.torch_transforms = torch.nn.Sequential(normalize)
================================================
FILE: lab07/text_recognizer/stems/line.py
================================================
import random
from PIL import Image
from torchvision import transforms
import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.stems.image import ImageStem
class LineStem(ImageStem):
    """Stem for images of a single line of text.

    With ``augment=True``, PIL images get a color jitter and then a random affine warp.
    Passing ``color_jitter_kwargs`` / ``random_affine_kwargs`` replaces the defaults.
    Without augmentation the base identity transforms are kept.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()
        if not augment:
            return

        if color_jitter_kwargs is not None:
            jitter = color_jitter_kwargs
        else:
            jitter = {"brightness": (0.5, 1)}

        if random_affine_kwargs is not None:
            affine = random_affine_kwargs
        else:
            affine = {
                "degrees": 3,
                "translate": (0, 0.05),
                "scale": (0.4, 1.1),
                "shear": (-40, 50),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        self.pil_transforms = transforms.Compose(
            [transforms.ColorJitter(**jitter), transforms.RandomAffine(**affine)]
        )
class IAMLineStem(ImageStem):
    """Stem for line crops from the IAMLines dataset.

    Every crop is resized to the dataset's image height (keeping its aspect ratio, and
    capped at the image width) and pasted onto a blank (IMAGE_WIDTH, IMAGE_HEIGHT)
    canvas of mode "L". With ``augment=True`` the crop is also randomly stretched
    horizontally, and the canvas then goes through a color jitter and a random affine warp.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()

        def embed_crop(crop, augment=augment):
            # crop is a PIL image of mode "L", so values range from 0 to 255
            canvas = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))

            src_width, src_height = crop.size
            target_height = metadata.IMAGE_HEIGHT
            target_width = int(target_height * (src_width / src_height))  # keep aspect ratio
            if augment:
                # random horizontal stretch of up to +/-10%
                target_width = int(target_width * random.uniform(0.9, 1.1))
            target_width = min(target_width, metadata.IMAGE_WIDTH)
            resized = crop.resize((target_width, target_height), resample=Image.BILINEAR)

            # left offset of at most one character width, never pushing the crop past the edge
            left = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - target_width)
            top = metadata.IMAGE_HEIGHT - target_height
            canvas.paste(resized, (left, top))
            return canvas

        if color_jitter_kwargs is None:
            color_jitter_kwargs = {"brightness": (0.8, 1.6)}
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 1,
                "shear": (-30, 20),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        steps = [transforms.Lambda(embed_crop)]
        if augment:
            steps.extend(
                [
                    transforms.ColorJitter(**color_jitter_kwargs),
                    transforms.RandomAffine(**random_affine_kwargs),
                ]
            )
        self.pil_transforms = transforms.Compose(steps)
================================================
FILE: lab07/text_recognizer/stems/paragraph.py
================================================
"""IAMParagraphs Stem class."""
import torchvision.transforms as transforms
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.stems.image import ImageStem
# Module-level aliases of the IAM paragraph metadata values used by this stem.
IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE  # target size for the center/random crops in ParagraphStem
MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
class ParagraphStem(ImageStem):
    """Stem for images containing a full paragraph of text.

    Without augmentation, images are center-cropped to ``IMAGE_SHAPE``. With
    ``augment=True`` they instead go through these steps, in order:

    1. color jitter;
    2. a random crop to ``IMAGE_SHAPE``, padded with 0 if needed;
    3. a random affine warp;
    4. a random perspective warp;
    5. Gaussian blur;
    6. random sharpening.

    Each ``*_kwargs`` argument, when not None, replaces that step's default settings.
    These arguments are only consulted when ``augment`` is true.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()
        if not augment:
            self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)])
            return

        def pick(user_kwargs, fallback):
            # Only None means "use the default"; any provided mapping, even an empty one, wins.
            return fallback if user_kwargs is None else user_kwargs

        jitter = pick(color_jitter_kwargs, {"brightness": 0.4, "contrast": 0.4})
        affine = pick(
            random_affine_kwargs,
            {
                "degrees": 3,
                "shear": 6,
                "scale": (0.95, 1),
                "interpolation": transforms.InterpolationMode.BILINEAR,
            },
        )
        perspective = pick(
            random_perspective_kwargs,
            {
                "distortion_scale": 0.2,
                "p": 0.5,
                "interpolation": transforms.InterpolationMode.BILINEAR,
            },
        )
        blur = pick(gaussian_blur_kwargs, {"kernel_size": (3, 3), "sigma": (0.1, 1.0)})
        sharpness = pick(sharpness_kwargs, {"sharpness_factor": 2, "p": 0.5})

        # IMAGE_SHAPE is (576, 640)
        random_crop = transforms.RandomCrop(
            size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant"
        )
        self.pil_transforms = transforms.Compose(
            [
                transforms.ColorJitter(**jitter),
                random_crop,
                transforms.RandomAffine(**affine),
                transforms.RandomPerspective(**perspective),
                transforms.GaussianBlur(**blur),
                transforms.RandomAdjustSharpness(**sharpness),
            ]
        )
================================================
FILE: lab07/text_recognizer/tests/test_callback_utils.py
================================================
"""Tests for the text_recognizer.callbacks.util module."""
import random
import string
import tempfile
import pytorch_lightning as pl
from text_recognizer.callbacks.util import check_and_warn
def test_check_and_warn_simple():
    """check_and_warn reports a missing attribute and accepts one every object has."""

    class Foo:
        pass  # deliberately bare: only the attributes every object inherits

    missing_name = "".join(random.choices(string.ascii_lowercase, k=10))
    assert check_and_warn(Foo(), missing_name, "random feature")
    assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects")
def test_check_and_warn_tblogger():
    """Test that we return a truthy value when trying to log tables with TensorBoard.

    We added check_and_warn in order to prevent a crash if this happens.
    """
    # Use the TemporaryDirectory as a context manager. This yields a real path string
    # (the object itself is not a path) and removes the directory when the test ends,
    # rather than whenever the object is garbage-collected.
    with tempfile.TemporaryDirectory() as save_dir:
        tblogger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
        assert check_and_warn(tblogger, "log_table", "tables")
def test_check_and_warn_wandblogger():
    """check_and_warn must not flag W&B table logging.

    The guard exists for loggers without log_table; it should stay silent on the happy path.
    """
    wandb_logger = pl.loggers.WandbLogger(anonymous=True)
    flagged = check_and_warn(wandb_logger, "log_table", "tables")
    assert not flagged
================================================
FILE: lab07/text_recognizer/tests/test_iam.py
================================================
"""Test for data.iam module."""
from text_recognizer.data.iam import IAM
def test_iam_parsed_lines():
    """Each IAM form must have exactly as many line labels as line crop regions."""
    iam = IAM()
    iam.prepare_data()
    for form_id in iam.all_ids:
        labels = iam.line_strings_by_id[form_id]
        regions = iam.line_regions_by_id[form_id]
        assert len(labels) == len(regions)
def test_iam_data_splits():
    """The train, validation, and test splits must not share any identifiers."""
    iam = IAM()
    iam.prepare_data()
    train, validation, test = set(iam.train_ids), set(iam.validation_ids), set(iam.test_ids)
    assert train.isdisjoint(validation)
    assert train.isdisjoint(test)
    assert validation.isdisjoint(test)
================================================
FILE: lab07/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """One-hot encode integer labels ``y`` as ``uint8``.

    The result has a trailing axis of size ``num_classes``.
    """
    identity = np.eye(num_classes, dtype=np.uint8)
    return identity[y]
def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Read an image from any location smart_open can open and return it as a PIL image.

    Parameters
    ----------
    image_uri
        Local path or remote URI (anything ``smart_open.open`` accepts).
    grayscale
        If True, convert the image to single-channel mode "L".

    Returns
    -------
    PIL.Image.Image
        The loaded image. The return annotation names the Image class; the previous
        ``-> Image`` annotated with the PIL module itself.
    """
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Load a PIL image from an open binary file (or anything ``PIL.Image.open`` accepts).

    Parameters
    ----------
    image_file
        Readable binary file object or path.
    grayscale
        If True, convert to mode "L"; otherwise keep the image's own mode.

    Returns
    -------
    PIL.Image.Image
        The image produced by ``Image.convert``. The annotation previously named the
        PIL module rather than the Image class.
    """
    with Image.open(image_file) as image:
        # convert() returns a converted copy (per the Pillow docs), so the returned image
        # does not depend on the file-backed image that this `with` block closes.
        target_mode = "L" if grayscale else image.mode
        image = image.convert(mode=target_mode)
    return image
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Context manager: run the ``with`` body inside ``working_dir``.

    The previous working directory is restored on exit, even if the body raises.
    """
    previous_dir = os.getcwd()
    os.chdir(working_dir)
    try:
        yield
    finally:
        os.chdir(previous_dir)  # always restore, including on exceptions
def encode_b64_image(image, format="png"):
    """Serialize a PIL image in ``format`` and return those bytes as base64 text."""
    with BytesIO() as in_memory_file:  # file-like buffer that never touches disk
        image.save(in_memory_file, format=format)
        raw_bytes = in_memory_file.getvalue()
    return base64.b64encode(raw_bytes).decode("utf8")
def compute_sha256(filename: Union[Path, str], chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of a file.

    The file is streamed in ``chunk_size``-byte pieces, so memory use stays constant even
    for large downloads; the previous version read the whole file at once.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive (a zero-size read would hash nothing).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """tqdm progress bar driven by ``urllib.request.urlretrieve``'s ``reporthook``.

    From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """Move the bar to ``blocks * bsize`` units transferred.

        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). The current total is left unchanged when this is
            None or negative. urlretrieve reports -1 when the server sends no size, and
            that must not become a negative bar total.
        """
        if tsize is not None and tsize >= 0:
            self.total = tsize
        self.update(blocks * bsize - self.n)  # will also set self.n = b * bsize
def download_url(url, filename):
    """Fetch ``url`` into the local file ``filename``, showing a byte-scaled progress bar."""
    bar_options = {"unit": "B", "unit_scale": True, "unit_divisor": 1024, "miniters": 1}
    with TqdmUpTo(**bar_options) as progress:
        urlretrieve(url, filename, reporthook=progress.update_to, data=None)  # noqa: S310
================================================
FILE: lab07/training/__init__.py
================================================
================================================
FILE: lab07/training/cleanup_artifacts.py
================================================
"""Removes artifacts from projects and runs.
Artifacts are binary files that we want to track
and version but don't want to include in git,
generally because they are too large,
because they don't have meaningful diffs,
or because they change more quickly than code.
During development, we often generate artifacts
that we don't really need, e.g. model weights for
an overfitting test run. Space on artifact storage
is generally very large, but it is limited,
so we should occasionally delete unneeded artifacts
to reclaim some of that space.
For usage help, run
python training/cleanup_artifacts.py --help
"""
import argparse
import wandb
# W&B public API client, created when this module is imported.
api = wandb.Api()
DEFAULT_PROJECT = "fsdl-text-recognizer-2022-training"
# Entity used when --entity=DEFAULT is passed; presumably the logged-in account's
# default entity (per wandb.Api.default_entity) -- confirm for team setups.
DEFAULT_ENTITY = api.default_entity
def _setup_parser():
    """Build the CLI.

    Options choose where to look (entity, project, runs), which artifacts to select
    (at most one of --all / --no-alias / --aliases), and how to report (-v, --dryrun).
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # --- where to look ---
    parser.add_argument(
        "--entity",
        type=str,
        default=None,
        help="The entity from which to remove artifacts. Provide the value DEFAULT "
        f"to use the default WANDB_ENTITY, which is currently {DEFAULT_ENTITY}.",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=DEFAULT_PROJECT,
        help=f"The project from which to remove artifacts. Default is {DEFAULT_PROJECT}",
    )
    run_list_options = {"type": str, "default": None, "nargs": "*"}
    parser.add_argument(
        "--run_ids",
        help="One or more run IDs from which to remove artifacts. Default is None.",
        **run_list_options,
    )
    parser.add_argument(
        "--run_name_res",
        metavar="RUN_NAME_REGEX",
        help="One or more regular expressions to use to select runs (by display name) from which to remove artifacts. See wandb.Api.runs documentation for details on the syntax. Beware that this is a footgun and consider using interactively with --dryrun and -v. Default is None.",
        **run_list_options,
    )

    # --- which artifacts: the selection modes are mutually exclusive ---
    selection = parser.add_mutually_exclusive_group()
    selection.add_argument("--all", action="store_true", help="Delete all artifacts from selected runs.")
    selection.add_argument(
        "--no-alias", action="store_true", help="Delete all artifacts without an alias from selected runs."
    )
    selection.add_argument(
        "--aliases",
        type=str,
        nargs="*",
        help="Delete artifacts that have any of the aliases from the provided list from selected runs.",
    )

    # --- reporting / safety ---
    parser.add_argument(
        "-v",
        dest="verbose",
        action="store_true",
        help="Display information about targeted entities, projects, runs, and artifacts.",
    )
    parser.add_argument(
        "--dryrun",
        action="store_true",
        help="Select artifacts without deleting them and display which artifacts were selected.",
    )
    return parser
def main(args):
    """Select runs and delete (or, with --dryrun, only list) their matching artifacts.

    ``args.verbose`` (the -v flag) is now forwarded to every helper that accepts it.
    Previously the entity and selector helpers never printed, despite -v promising
    information about targeted entities.
    """
    verbose = args.verbose
    entity = _get_entity_from(args, verbose=verbose)
    project_path = f"{entity}/{args.project}"
    runs = _get_runs(project_path, args.run_ids, args.run_name_res, verbose=verbose)
    artifact_selector = _get_selector_from(args, verbose=verbose)
    # In --no-alias mode only unaliased artifacts are selected. Deletion is also told not to
    # delete aliases (delete_aliases=False), to avoid deletion of any aliased artifacts.
    protect_aliases = args.no_alias
    for run in runs:
        clean_run_artifacts(
            run, selector=artifact_selector, protect_aliases=protect_aliases, verbose=verbose, dryrun=args.dryrun
        )
def clean_run_artifacts(run, selector, protect_aliases=True, verbose=False, dryrun=True):
    """Pass every artifact logged by ``run`` that ``selector`` accepts to ``remove_artifact``.

    Defaults to a dry run, so nothing is deleted unless ``dryrun=False``.
    """
    chosen = (artifact for artifact in run.logged_artifacts() if selector(artifact))
    for artifact in chosen:
        remove_artifact(artifact, protect_aliases=protect_aliases, verbose=verbose, dryrun=dryrun)
def remove_artifact(artifact, protect_aliases, verbose=False, dryrun=True):
    """Report and, unless ``dryrun``, delete one W&B artifact.

    Parameters
    ----------
    artifact
        W&B artifact handle (needs ``entity``, ``project``, ``id``, ``type``, ``aliases``, ``delete``).
    protect_aliases
        If True, delete with ``delete_aliases=False``; otherwise aliases are deleted too.
    verbose
        Print the selection message even when actually deleting.
    dryrun
        If True (the default), only print the message; nothing is deleted.
    """
    if verbose or dryrun:
        # W&B paths are entity/project/..., matching the run messages in this script;
        # the previous message printed project and entity in reverse order.
        print(
            f"selecting for deletion artifact {artifact.entity}/{artifact.project}/{artifact.id}"
            f" of type {artifact.type} with aliases {artifact.aliases}"
        )
    if not dryrun:
        artifact.delete(delete_aliases=not protect_aliases)
def _get_runs(project_path, run_ids=None, run_name_res=None, verbose=False):
if run_ids is None:
run_ids = []
if run_name_res is None:
run_name_res = []
runs = []
for run_id in run_ids:
runs.append(_get_run_by_id(project_path, run_id, verbose=verbose))
for run_name_re in run_name_res:
runs += _get_runs_by_name_re(project_path, run_name_re, verbose=verbose)
return runs
def _get_run_by_id(project_path, run_id, verbose=False):
    """Fetch the run ``project_path/run_id`` through the W&B API, optionally announcing it."""
    run = api.run(f"{project_path}/{run_id}")
    if verbose:
        print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}")
    return run
def _get_runs_by_name_re(project_path, run_name_re, verbose=False):
    """Query ``project_path`` for runs whose display name matches the regex ``run_name_re``.

    Uses a ``$regex`` filter on ``display_name``; see the wandb.Api.runs docs for the syntax.
    """
    name_filter = {"display_name": {"$regex": run_name_re}}
    matching_runs = api.runs(path=project_path, filters=name_filter)
    if verbose:
        for run in matching_runs:
            print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}")
    return matching_runs
def _get_selector_from(args, verbose=False):
if args.all:
if verbose:
print("removing all artifacts from matching runs")
return lambda _: True
if args.no_alias:
if verbose:
print("removing all artifacts with no aliases from matching runs")
return lambda artifact: artifact.aliases == []
if args.aliases:
if verbose:
print(f"removing all artifacts with any of {args.aliases} in aliases from matching runs")
return lambda artifact: any(alias in artifact.aliases for alias in args.aliases)
if verbose:
print("removing no artifacts matching runs")
return lambda _: False
def _get_entity_from(args, verbose=False):
entity = args.entity
if entity is None:
raise RuntimeError(f"No entity argument provided. Use --entity=DEFAULT to use {DEFAULT_ENTITY}.")
elif entity == "DEFAULT":
entity = DEFAULT_ENTITY
if verbose:
print(f"using default entity {entity}")
else:
if verbose:
print(f"using entity {entity}")
return entity
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and run the artifact cleanup.
    parser = _setup_parser()
    args = parser.parse_args()
    main(args)
================================================
FILE: lab07/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import callbacks as cb
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# In order to ensure reproducible experiments, we must set random seeds.
# These calls seed NumPy's and PyTorch's global generators, once, at import time.
# NOTE(review): Python's stdlib `random` module is not seeded here; code that draws from it
# (e.g. the random stretching in IAMLineStem) is therefore not reproducible -- confirm intent.
np.random.seed(42)
torch.manual_seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments.

    Parsing is two-stage. After the generic arguments are registered, ``parse_known_args``
    reads ``--data_class`` and ``--model_class`` so that those classes (and the LitModel)
    can add their own argument groups. The parsers are built with ``add_help=False``, and
    ``--help`` is registered last. This keeps the early pass from acting on --help, and lets
    the help text include the class-specific groups.
    """
    parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    # NOTE(review): relies on argparse's private `_action_groups`; which group sits at index 1
    # depends on how Lightning registers the Trainer args -- confirm the intended group is renamed.
    trainer_parser._action_groups[1].title = "Trainer Args"
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    # Train for a single epoch unless --max_epochs is given explicitly.
    parser.set_defaults(max_epochs=1)

    # Basic arguments
    parser.add_argument(
        "--wandb",
        action="store_true",
        default=False,
        help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.",
    )
    parser.add_argument(
        "--data_class",
        type=str,
        default="MNIST",
        help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--model_class",
        type=str,
        default="MLP",
        help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path."
    )
    parser.add_argument(
        "--stop_early",
        type=int,
        default=0,
        help="If non-zero, applies early stopping, with the provided value as the 'patience' argument."
        + " Default is 0.",
    )

    # Get the data and model classes, so that we can add their specific arguments.
    # parse_known_args tolerates flags that belong to groups not registered yet.
    temp_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}")

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)

    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)

    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_models.BaseLitModel.add_to_argparse(lit_model_group)

    # Registered last so that --help lists every group added above.
    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create ``experiment_dir`` and any missing parents; only the rank-zero process does this."""
    log_path = Path(experiment_dir)
    log_path.mkdir(parents=True, exist_ok=True)
def main():
    """
    Run an experiment.

    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```

    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```

    The available command line args differ depending on some of the arguments, including --model_class and --data_class.

    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```
    """
    parser = _setup_parser()
    args = parser.parse_args()
    data, model = setup_data_and_model_from_args(args)

    # --loss=transformer selects TransformerLitModel; every other loss uses BaseLitModel.
    lit_model_class = lit_models.BaseLitModel

    if args.loss == "transformer":
        lit_model_class = lit_models.TransformerLitModel

    # The freshly built `model` is handed to the LitModel whether or not a checkpoint is loaded.
    if args.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)
    else:
        lit_model = lit_model_class(args=args, model=model)

    # TensorBoard logging under training/logs is always set up first; its versioned
    # log_dir becomes the experiment directory.
    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    logger = pl.loggers.TensorBoardLogger(log_dir)
    experiment_dir = logger.log_dir

    # Metric used to rank checkpoints: validation CER for the transformer loss, validation loss otherwise.
    goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss"
    # The filename template spells out metric names itself (auto_insert_metric_name=False below),
    # presumably because the logged names contain "/" -- TODO confirm.
    filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
    if goldstar_metric == "validation/cer":
        filename_format += "-validation.cer={validation/cer:.3f}"
    # NOTE: dirpath is fixed to the TensorBoard experiment dir here. If --wandb later
    # reassigns `experiment_dir`, checkpoints are still written here.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename=filename_format,
        monitor=goldstar_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    summary_callback = pl.callbacks.ModelSummary(max_depth=2)

    callbacks = [summary_callback, checkpoint_callback]
    # With --wandb, a WandbLogger replaces TensorBoard as the Trainer's logger, and the
    # experiment directory switches to the W&B run's local directory.
    if args.wandb:
        logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train")
        logger.watch(model, log_freq=max(100, args.log_every_n_steps))  # log_freq is at least 100 steps
        logger.log_hyperparams(vars(args))
        experiment_dir = logger.experiment.dir
    callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()]
    # Early stopping always watches validation/loss, independent of goldstar_metric.
    if args.stop_early:
        early_stopping_callback = pl.callbacks.EarlyStopping(
            monitor="validation/loss", mode="min", patience=args.stop_early
        )
        callbacks.append(early_stopping_callback)

    if args.wandb and args.loss in ("transformer",):
        callbacks.append(cb.ImageToTextLogger())

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)

    # --profile: cycles of 0 wait / 3 warmup / 4 active steps. repeat=0 leaves the number of
    # cycles uncapped (per torch.profiler.schedule docs). The trace is written as Chrome-style
    # output into experiment_dir.
    if args.profile:
        sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0)
        profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir)
        profiler.STEP_FUNCTIONS = {"training_step"}  # only profile training
    else:
        profiler = pl.profiler.PassThroughProfiler()

    trainer.profiler = profiler

    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate

    trainer.fit(lit_model, datamodule=data)

    trainer.profiler = pl.profiler.PassThroughProfiler()  # turn profiling off during testing

    # Test the best saved checkpoint when there is one; otherwise test the in-memory model.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        rank_zero_info(f"Best model saved at: {best_model_path}")
        if args.wandb:
            rank_zero_info("Best model also uploaded to W&B ")
        trainer.test(datamodule=data, ckpt_path=best_model_path)
    else:
        trainer.test(lit_model, datamodule=data)
# Launch an experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()
================================================
FILE: lab07/training/stage_model.py
================================================
"""Stages a model for use in production.
If based on a checkpoint, the model is converted to torchscript, saved locally,
and uploaded to W&B.
If based on a model that is already converted and uploaded, the model file is downloaded locally.
For details on how the W&B artifacts backing the checkpoints and models are handled,
see the documentation for stage_model.find_artifact.
"""
import argparse
from pathlib import Path
import tempfile
import torch
import wandb
from text_recognizer.lit_models import TransformerLitModel
from training.util import setup_data_and_model_from_args
# these names are all set by the pl.loggers.WandbLogger
MODEL_CHECKPOINT_TYPE = "model"
BEST_CHECKPOINT_ALIAS = "best"
MODEL_CHECKPOINT_PATH = "model.ckpt"
# Local directory handed to wandb.init for the staging run.
LOG_DIR = Path("training") / "logs"
STAGED_MODEL_TYPE = "prod-ready"  # we can choose the name of this type, and ideally it's different from checkpoints
STAGED_MODEL_FILENAME = "model.pt"  # standard nomenclature; pytorch_model.bin is also used
# Repository root: the parent of the directory containing this file (training/).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# LightningModule class used to reload checkpoints before TorchScript export.
LITMODEL_CLASS = TransformerLitModel

# W&B public API client, created when this module is imported.
api = wandb.Api()

DEFAULT_ENTITY = api.default_entity
DEFAULT_FROM_PROJECT = "fsdl-text-recognizer-2022-training"
DEFAULT_TO_PROJECT = "fsdl-text-recognizer-2022-training"
DEFAULT_STAGED_MODEL_NAME = "paragraph-text-recognizer"

# Staged models are written to text_recognizer/artifacts/<staged_model_name>.
PROD_STAGING_ROOT = PROJECT_ROOT / "text_recognizer" / "artifacts"
def main(args):
    """Stage a model into PROD_STAGING_ROOT / <staged_model_name>.

    With --fetch, just download the latest version of an already-staged model artifact.
    Otherwise, inside a W&B run of job type "stage":

    1. locate a checkpoint artifact;
    2. rebuild the LightningModule from it;
    3. export it to TorchScript in the staging directory;
    4. log the result as a new artifact carrying the checkpoint's metadata.
    """
    staging_dir = PROD_STAGING_ROOT / args.staged_model_name
    staging_dir.mkdir(exist_ok=True, parents=True)
    entity = _get_entity_from(args)

    if args.fetch:
        latest_staged = f"{entity}/{args.from_project}/{args.staged_model_name}:latest"
        fetched = download_artifact(latest_staged, staging_dir)
        print_info(fetched)
        return

    # Logging the staging step as a W&B run connects the production model to its training run.
    with wandb.init(job_type="stage", project=args.to_project, dir=LOG_DIR):
        ckpt_path, ckpt_artifact = find_artifact(
            entity, args.from_project, type=MODEL_CHECKPOINT_TYPE, alias=args.ckpt_alias, run=args.run
        )
        producer_run = get_logging_run(ckpt_artifact)
        print_info(ckpt_artifact, producer_run)

        ckpt_metadata = get_checkpoint_metadata(producer_run, ckpt_artifact)
        staged_artifact = wandb.Artifact(args.staged_model_name, type=STAGED_MODEL_TYPE, metadata=ckpt_metadata)

        # The checkpoint file is only needed long enough to rebuild the model.
        with tempfile.TemporaryDirectory() as tmp_dir:
            download_artifact(ckpt_path, tmp_dir)
            lit_model = load_model_from_checkpoint(ckpt_metadata, directory=tmp_dir)

        save_model_to_torchscript(lit_model, directory=staging_dir)
        upload_staged_model(staged_artifact, from_directory=staging_dir)
def find_artifact(entity: str, project: str, type: str, alias: str, run=None):
    """Locate the artifact of a given type and alias under an entity and project.

    Parameters
    ----------
    entity
        W&B entity under which the artifact was logged.
    project
        W&B project under which the artifact was logged.
    type
        Artifact type name.
    alias : str
        Alias to look for. It must be unique for this type within ``run`` when a run
        is given, or within the whole project otherwise.
    run : str
        Optional run ID to restrict the search to.

    Returns
    -------
    Tuple[path, artifact]
        The artifact's identifying path and a W&B API handle for it.
    """
    if run is None:
        artifact_path = _find_artifact_project(entity, project, type=type, alias=alias)
    else:
        artifact_path = _find_artifact_run(entity, project, type=type, run=run, alias=alias)
    handle = api.artifact(artifact_path)
    return artifact_path, handle
def get_logging_run(artifact):
    """Return the W&B run that logged ``artifact``."""
    return artifact.logged_by()
def print_info(artifact, run=None):
    """Print an artifact's name and web URL, plus the run that logged it and that run's URL.

    Parameters
    ----------
    artifact
        W&B artifact handle.
    run
        The logging run. If omitted, it is looked up with ``get_logging_run``.
    """
    if run is None:
        run = get_logging_run(artifact)

    full_artifact_name = f"{artifact.entity}/{artifact.project}/{artifact.name}"
    print(f"Using artifact {full_artifact_name}")
    artifact_url_prefix = f"https://wandb.ai/{artifact.entity}/{artifact.project}/artifacts/{artifact.type}"
    # The web URL uses "/" where the artifact name has ":" (name:version -> name/version).
    artifact_url_suffix = artifact.name.replace(":", "/")
    print(f"View at URL: {artifact_url_prefix}/{artifact_url_suffix}")

    # Use the same entity/project/... order as the artifact name above; this previously
    # printed project/entity/id.
    print(f"Logged by {run.name} -- {run.entity}/{run.project}/{run.id}")
    print(f"View at URL: {run.url}")
def get_checkpoint_metadata(run, checkpoint):
    """Bundle the logging run's config with details recorded on the checkpoint artifact.

    The result always has ``"config"``. On a best-effort basis it then gains
    ``"original_filename"`` and ``{monitored metric name: score}``, in that order.
    Filling stops silently at the first missing key.
    """
    metadata = {"config": run.config}
    try:
        ckpt_meta = checkpoint.metadata
        metadata["original_filename"] = ckpt_meta["original_filename"]
        monitored_metric = ckpt_meta["ModelCheckpoint"]["monitor"]
        metadata[monitored_metric] = ckpt_meta["score"]
    except KeyError:
        pass  # partial metadata is acceptable
    return metadata
def download_artifact(artifact_path, target_directory):
    """Download the artifact at ``artifact_path`` into ``target_directory`` and return its handle.

    Inside an active W&B run, ``wandb.use_artifact`` is used so the run records that it
    consumed the artifact. Otherwise the public API is used directly.
    """
    fetch = api.artifact if wandb.run is None else wandb.use_artifact
    artifact = fetch(artifact_path)
    artifact.download(root=target_directory)
    return artifact
def load_model_from_checkpoint(ckpt_metadata, directory):
    """Rebuild the LightningModule from ``directory / MODEL_CHECKPOINT_PATH``, in eval mode.

    The training run's config (``ckpt_metadata["config"]``) is replayed as CLI args to
    reconstruct the underlying model before the checkpoint weights are loaded.
    """
    run_args = argparse.Namespace(**ckpt_metadata["config"])
    _, model = setup_data_and_model_from_args(run_args)

    checkpoint_file = Path(directory) / MODEL_CHECKPOINT_PATH
    lit_model = LITMODEL_CLASS.load_from_checkpoint(
        checkpoint_path=checkpoint_file, args=run_args, model=model, strict=False
    )
    lit_model.eval()
    return lit_model
def save_model_to_torchscript(model, directory):
    """Script ``model`` with TorchScript and save it as ``directory / STAGED_MODEL_FILENAME``."""
    scripted = model.to_torchscript(method="script", file_path=None)
    destination = Path(directory) / STAGED_MODEL_FILENAME
    torch.jit.save(scripted, destination)
def upload_staged_model(staged_at, from_directory):
    """Add the staged TorchScript file to ``staged_at`` and log the artifact to the active run."""
    model_file = Path(from_directory) / STAGED_MODEL_FILENAME
    staged_at.add_file(model_file)
    wandb.log_artifact(staged_at)
def _find_artifact_run(entity, project, type, run, alias):
    """Return the path of the single artifact of ``type`` tagged ``alias`` that ``run`` logged.

    Raises ValueError if there is no such artifact or more than one.
    """
    run_name = f"{entity}/{project}/{run}"
    logged = api.run(run_name).logged_artifacts()
    match = [art for art in logged if alias in art.aliases and art.type == type]
    count = len(match)
    if count == 0:
        raise ValueError(f"No artifact with alias {alias} found at {run_name} of type {type}")
    if count > 1:
        raise ValueError(f"Multiple artifacts ({len(match)}) with alias {alias} found at {run_name} of type {type}")
    return f"{entity}/{project}/{match[0].name}"
def _find_artifact_project(entity, project, type, alias):
    """Return the path of the first artifact version of ``type`` tagged ``alias`` in a project.

    Versions are scanned collection by collection, in the order the W&B API returns them.

    Raises
    ------
    ValueError
        If no version of that type carries ``alias``, or if the project has no artifact type
        named ``type`` (the project may also be private or nonexistent).
    """
    project_name = f"{entity}/{project}"
    api_project = api.project(project, entity=entity)
    same_named_types = (t for t in api_project.artifacts_types() if t.name == type)
    for artifact_type in same_named_types:
        for collection in artifact_type.collections():
            for version in collection.versions():
                if alias in version.aliases:
                    return f"{project_name}/{version.name}"
        # The requested type exists but no version of it has the alias.
        raise ValueError(f"Artifact with alias {alias} not found in type {type} in {project_name}")
    raise ValueError(f"Artifact type {type} not found. {project_name} could be private or not exist.")
def _get_entity_from(args):
entity = args.entity
if entity is None:
raise RuntimeError(f"No entity argument provided. Use --entity=DEFAULT to use {DEFAULT_ENTITY}.")
elif entity == "DEFAULT":
entity = DEFAULT_ENTITY
return entity
def _setup_parser():
    """Build the command-line parser for staging and fetching model artifacts.

    Returns
    -------
    argparse.ArgumentParser
        Parser for the --fetch, --entity, --from_project, --to_project, --run,
        --ckpt_alias and --staged_model_name flags.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--fetch",
        action="store_true",
        help=f"If provided, check ENTITY/FROM_PROJECT for an artifact with the provided STAGED_MODEL_NAME and download its latest version to {PROD_STAGING_ROOT}/STAGED_MODEL_NAME.",
    )
    parser.add_argument(
        "--entity",
        type=str,
        default=None,
        help=f"Entity from which to download the checkpoint. Note that checkpoints are always uploaded to the logged-in wandb entity. Pass the value 'DEFAULT' to also download from default entity, which is currently {DEFAULT_ENTITY}.",
    )
    parser.add_argument(
        "--from_project",
        type=str,
        default=DEFAULT_FROM_PROJECT,
        # terminal period added for consistency with the other help strings
        help=f"Project from which to download the checkpoint. Default is {DEFAULT_FROM_PROJECT}.",
    )
    parser.add_argument(
        "--to_project",
        type=str,
        default=DEFAULT_TO_PROJECT,
        help=f"Project to which to upload the compiled model. Default is {DEFAULT_TO_PROJECT}.",
    )
    parser.add_argument(
        "--run",
        type=str,
        default=None,
        help=f"Optionally, the name of a run to check for an artifact of type {MODEL_CHECKPOINT_TYPE} that has the provided CKPT_ALIAS. Default is None.",
    )
    parser.add_argument(
        "--ckpt_alias",
        type=str,
        default=BEST_CHECKPOINT_ALIAS,
        # fixed missing space between sentences ("staged.The" -> "staged. The")
        help=f"Alias that identifies which model checkpoint should be staged. The artifact's alias can be set manually or programmatically elsewhere. Default is {BEST_CHECKPOINT_ALIAS!r}.",
    )
    parser.add_argument(
        "--staged_model_name",
        type=str,
        default=DEFAULT_STAGED_MODEL_NAME,
        help=f"Name to give the staged model artifact. Default is {DEFAULT_STAGED_MODEL_NAME!r}.",
    )
    return parser
if __name__ == "__main__":
    # command-line entry point: parse the flags defined in _setup_parser and hand them to main
    parser = _setup_parser()
    args = parser.parse_args()
    main(args)
================================================
FILE: lab07/training/tests/test_memorize_iam.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# Memorization test: check that training on a single batch drives the
# loss below a criterion within a bounded number of epochs.
# The defaults aim for a <5 min run on a commodity GPU, including the data download.
# Usage: test_memorize_iam.sh [MAX_EPOCHS] [CRITERION]
FAILURE=false
MAX_EPOCHS="${1:-100}"
CRITERION="${2:-1.0}"

# 1 when CUDA is available (train on GPU), 0 otherwise
GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))')

TRAIN_ARGS=(
  --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer
  --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0
  --augment_data false --tf_dropout 0.0
  --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001
  --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb
)
python ./training/run_experiment.py "${TRAIN_ARGS[@]}" || FAILURE=true

# read the final train/loss from the latest W&B run summary and compare it with the criterion
python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true

if [ "$FAILURE" != true ]; then
  echo "Memorization test passed at loss criterion $CRITERION"
  exit 0
fi
echo "Memorization test failed at loss criterion $CRITERION"
exit 1
================================================
FILE: lab07/training/tests/test_model_development.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# End-to-end check of the model development loop:
# train a tiny model, stage its checkpoint, fetch the staged model back, then clean up.
FAILURE=false

# keep CI runs and local runs in separate W&B projects
CI="${CI:-false}"
if [ "$CI" = false ]; then
  export WANDB_PROJECT="fsdl-testing-2022"
else
  export WANDB_PROJECT="fsdl-testing-2022-ci"
fi

echo "training smaller version of real model class on real data"
python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \
  --limit_train_batches 1 --limit_val_batches 1 --limit_test_batches 1 --num_sanity_val_steps 0 \
  --num_workers 1 --wandb || FAILURE=true
# recover the run id from the run-<id>.wandb file in the latest local W&B run directory
TRAIN_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//")

echo "staging trained model from run $TRAIN_RUN"
python training/stage_model.py --entity DEFAULT --run "$TRAIN_RUN" --staged_model_name test-dummy --ckpt_alias latest --to_project "$WANDB_PROJECT" --from_project "$WANDB_PROJECT" || FAILURE=true

echo "fetching staged model"
# quoted like every other use of $WANDB_PROJECT (avoids word splitting / globbing, SC2086)
python training/stage_model.py --entity DEFAULT --fetch --from_project "$WANDB_PROJECT" --staged_model_name test-dummy || FAILURE=true
STAGE_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//")

# never reach the --all cleanup below without both run ids;
# an empty id must not risk selecting artifacts from the whole project
if [ -z "$TRAIN_RUN" ] || [ -z "$STAGE_RUN" ]; then
  echo "could not determine run ids (train: '$TRAIN_RUN', stage: '$STAGE_RUN')"
  FAILURE=true
fi

if [ "$FAILURE" = true ]; then
  echo "Model development test failed"
  echo "cleaning up local files"
  rm -rf text_recognizer/artifacts/test-dummy
  echo "leaving remote files in place"
  exit 1
fi

echo "cleaning up local and remote files"
rm -rf text_recognizer/artifacts/test-dummy
python training/cleanup_artifacts.py --entity DEFAULT --project "$WANDB_PROJECT" \
  --run_ids "$TRAIN_RUN" "$STAGE_RUN" --all -v
# note: both run ids are checked above; --all would otherwise put every artifact in the project at risk.
echo "Model development test passed"
exit 0
================================================
FILE: lab07/training/tests/test_run_experiment.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# Smoke tests for training/run_experiment.py.
# Every step runs even if an earlier one fails; the overall status is reported at the end.
FAILURE=false

echo "running full loop test with CNN on fake data"
CNN_ARGS=(
  --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2
  --loss=cross_entropy --num_workers=4 --max_epochs=1
)
python training/run_experiment.py "${CNN_ARGS[@]}" || FAILURE=true

echo "running fast_dev_run test of real model class on real data"
REAL_ARGS=(
  --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer
  --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001
  --fast_dev_run --num_sanity_val_steps 0
  --num_workers 1
)
python training/run_experiment.py "${REAL_ARGS[@]}" || FAILURE=true

if [ "$FAILURE" != true ]; then
  echo "Tests for run_experiment.py passed"
  exit 0
fi
echo "Test for run_experiment.py failed"
exit 1
================================================
FILE: lab07/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
# package that the --data_class name is resolved against in setup_data_and_model_from_args
DATA_CLASS_MODULE = "text_recognizer.data"
# package that the --model_class name is resolved against in setup_data_and_model_from_args
MODEL_CLASS_MODULE = "text_recognizer.models"
def import_class(module_and_class_name: str) -> type:
    """Import a module and return the attribute named by the last dotted component.

    For example, 'text_recognizer.models.MLP' imports `text_recognizer.models`
    and returns its `MLP` attribute.
    """
    module_name, class_name = module_and_class_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Instantiate the data class and model class named by `args.data_class` and `args.model_class`.

    The model is constructed with the data object's `config()` and the same `args`.

    Returns
    -------
    tuple
        (data, model)
    """
    data_class, model_class = (
        import_class(f"{DATA_CLASS_MODULE}.{args.data_class}"),
        import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}"),
    )
    data = data_class(args)
    return data, model_class(data_config=data.config(), args=args)
================================================
FILE: lab08/.flake8
================================================
[flake8]
select = ANN,B,B9,BLK,C,D,E,F,I,S,W
# only check selected error codes
max-complexity = 12
# C9 - flake8 McCabe Complexity checker -- threshold
max-line-length = 120
# E501 - flake8 -- line length too long, actually handled by black
extend-ignore =
# E W - flake8 PEP style check
E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks
# S - flake8-bandit safety check
S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password
# ANN - flake8-annotations type annotation check
ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some
# D1 - flake8-docstrings docstring style check
D100,D102,D103,D104,D105, # missing docstrings
# D2 D4 - flake8-docstrings docstring style check
D200,D205,D400,D401, # whitespace issues and first line content
# DAR - flake8-darglint docstring correctness check
DAR103, # mismatched or missing type in docstring
application-import-names = app_gradio,text_recognizer,tests,training
# flake8-import-order: which names are first party?
import-order-style = google
# flake8-import-order: which import order style guide do we use?
docstring-convention = numpy
# flake8-docstrings: which docstring style guide do we use?
strictness = short
# darglint: how "strict" are we with docstring completeness?
docstring-style = numpy
# darglint: which docstring style guide do we use?
suppress-none-returning = true
# flake8-annotations: do we allow un-annotated Nones in returns?
mypy-init-return = true
# flake8-annotations: do we allow init to have no return annotation?
per-file-ignores =
# list of case-by-case ignores, see files for details
*/__init__.py:F401,I
*/data/*.py:DAR
data/*.py:F,I
*text_recognizer/util.py:DAR101,F401
*training/run_experiment.py:I202
*app_gradio/app.py:I202
================================================
FILE: lab08/.github/workflows/pre-commit.yml
================================================
name: pre-commit
on:
pull_request:
push:
# allows this Action to be triggered manually
workflow_dispatch:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.10'
- uses: pre-commit/action@v3.0.0
================================================
FILE: lab08/.pre-commit-config.yaml
================================================
repos:
# a set of useful Python-based pre-commit hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
hooks:
# list of definitions and supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace # removes any whitespace at the ends of lines
- id: check-toml # check toml syntax by loading all toml files
- id: check-yaml # check yaml syntax by loading all yaml files
- id: check-json # check-json syntax by loading all json files
- id: check-merge-conflict # check for files with merge conflict strings
args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge
- id: check-added-large-files # check that no "large" files have been added
args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server
- id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.)
- id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.)
# black python autoformatting
- repo: https://github.com/psf/black
rev: 22.3.0
hooks:
- id: black
# additional configuration of black in pyproject.toml
# flake8 python linter with all the fixins
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08)
additional_dependencies: [
flake8-bandit, flake8-bugbear, flake8-docstrings,
flake8-import-order, darglint, mypy, pycodestyle, pydocstyle]
args: ["--config", ".flake8"]
# additional configuration of flake8 and extensions in .flake8
# shellcheck-py for linting shell files
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.8.0.4
hooks:
- id: shellcheck
================================================
FILE: lab08/api_serverless/Dockerfile
================================================
# Starting from an official AWS image
# Keep any dependencies and versions in this file aligned with the environment.yml and Makefile
FROM public.ecr.aws/lambda/python:3.10
# Install Python dependencies
COPY requirements/prod.txt ./requirements.txt
RUN pip install --upgrade pip==23.1.2
RUN pip install -r requirements.txt
# Copy only the relevant directories and files
# note that we use a .dockerignore file to avoid copying logs etc.
COPY text_recognizer/ ./text_recognizer
COPY api_serverless/api.py ./api.py
CMD ["api.handler"]
================================================
FILE: lab08/api_serverless/__init__.py
================================================
"""Cloud function-backed API for paragraph recognition."""
================================================
FILE: lab08/api_serverless/api.py
================================================
"""AWS Lambda function serving text_recognizer predictions."""
import json
from PIL import ImageStat
from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer
import text_recognizer.util as util
# Built once at import time, so every handler() call in this process reuses the
# same recognizer instead of constructing a new one per request.
model = ParagraphTextRecognizer()
def handler(event, _context):
    """Provide main prediction API.

    Extracts an image from `event` (see _load_image), runs the module-level model
    on it, prints METRIC/INFO lines, and returns {"pred": <text>}. If no image can
    be found, it returns a dict with "statusCode": 400 and an explanatory message.
    """
    print("INFO loading image")
    image = _load_image(event)
    if image is None:
        return {"statusCode": 400, "message": "neither image_url nor image found in event"}
    print("INFO image loaded")

    print("INFO starting inference")
    pred = model.predict(image)
    print("INFO inference complete")

    mean_intensity = ImageStat.Stat(image).mean[0]
    width, height = image.size
    print(f"METRIC image_mean_intensity {mean_intensity}")
    print(f"METRIC image_area {width * height}")
    print(f"METRIC pred_length {len(pred)}")
    print(f"INFO pred {pred}")
    return {"pred": str(pred)}
def _load_image(event):
    """Extract a grayscale PIL image from an invocation event, or return None.

    `event` may be a dict or a JSON string. If it carries a non-empty "body" (as
    HTTP-triggered invocations do), that body, itself possibly a JSON string, is
    searched instead. An "image_url" key takes precedence over a base64 "image" key.

    Returns
    -------
    PIL image or None
        None when neither "image_url" nor "image" is present.
    """
    event = _from_string(event)
    body = event.get("body")
    # HTTP requests without a payload can carry "body": null (or ""). Fall back to the
    # event itself rather than crashing on None.get / json.loads("").
    event = _from_string(event if body in (None, "") else body)
    image_url = event.get("image_url")
    if image_url is not None:
        print("INFO url {}".format(image_url))
        return util.read_image_pil(image_url, grayscale=True)
    image = event.get("image")
    if image is None:
        return None
    print("INFO reading image from event")
    return util.read_b64_image(image, grayscale=True)
def _from_string(event):
if isinstance(event, str):
return json.loads(event)
else:
return event
================================================
FILE: lab08/app_gradio/Dockerfile
================================================
# The "buster" flavor of the official docker Python image is based on Debian and includes common packages.
# Keep any dependencies and versions in this file aligned with the environment.yml and Makefile
FROM python:3.10-buster
# Create the working directory
# set -x prints commands and set -e causes us to stop on errors
RUN set -ex && mkdir /repo
WORKDIR /repo
# Install Python dependencies
COPY requirements/prod.txt ./requirements.txt
RUN pip install --upgrade pip==23.1.2
RUN pip install -r requirements.txt
ENV PYTHONPATH ".:"
# Copy only the relevant directories
# note that we use a .dockerignore file to avoid copying logs etc.
COPY text_recognizer/ ./text_recognizer
COPY app_gradio/ ./app_gradio
# Use docker run -it --rm -p $PORT:11717 to run the web server and listen on host $PORT
# add --help to see help for the Python script
ENTRYPOINT ["python3", "app_gradio/app.py", "--port", "11717"]
================================================
FILE: lab08/app_gradio/README.md
================================================
## Full-Paragraph Optical Character Recognition
For more on how this application works,
[check out the GitHub repo](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022).
### Flagging
If the model outputs in the top-right are wrong in some way,
let us know by clicking the "flagging" buttons underneath.
We'll analyze the results with
[Gantry](https://gantry.io/blog/introducing-gantry/)
and use them to improve the model!
================================================
FILE: lab08/app_gradio/__init__.py
================================================
================================================
FILE: lab08/app_gradio/app.py
================================================
"""Provide an image of handwritten text and get back out a string!"""
import argparse
import json
import logging
import os
from pathlib import Path
from typing import Callable
import warnings
import gradio as gr
from PIL import ImageStat
from PIL.Image import Image
import requests
from app_gradio.flagging import GantryImageToTextLogger, get_api_key
from app_gradio.s3_util import make_unique_bucket_name
from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer
import text_recognizer.util as util
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # do not use GPU: an empty device list hides all CUDA devices
logging.basicConfig(level=logging.INFO)  # so the logging.info METRIC/PRED lines from PredictorBackend are emitted
DEFAULT_APPLICATION_NAME = "fsdl-text-recognizer"  # default for --application (Gantry app / bucket prefix)
APP_DIR = Path(__file__).resolve().parent  # what is the directory for this application?
FAVICON = APP_DIR / "1f95e.png"  # path to a small image for display in browser tab and social media
README = APP_DIR / "README.md"  # path to an app readme file in HTML/markdown
DEFAULT_PORT = 11700  # default for --port
def main(args):
    """Build the prediction backend and Gradio UI from parsed CLI args, then launch the server."""
    predictor = PredictorBackend(url=args.model_url)
    frontend = make_frontend(
        predictor.run,
        flagging=args.flagging,
        gantry=args.gantry,
        app_name=args.application,
    )
    launch_options = {
        "server_name": "0.0.0.0",  # bind all interfaces so the server is reachable  # noqa: S104
        "server_port": args.port,  # fail rather than pick another port if this one is taken
        "share": True,  # also create a temporary public link on https://gradio.app
        "favicon_path": FAVICON,  # icon shown in the browser's address bar
    }
    frontend.launch(**launch_options)
def make_frontend(
    fn: Callable[[Image], str], flagging: bool = False, gantry: bool = False, app_name: str = "fsdl-text-recognizer"
):
    """Creates a gradio.Interface frontend for an image to text function.

    Parameters
    ----------
    fn
        Function mapping a PIL image to predicted text.
    flagging
        If True, show Gradio's "flag" buttons so users can give feedback.
    gantry
        If True (and flagging is on and a Gantry API key is available), log flags
        to Gantry via S3. Otherwise flags go to a local CSV in "flagged".
    app_name
        Gantry application name; also used as the S3 bucket-name prefix.

    Returns
    -------
    gr.Interface
        The configured, not-yet-launched interface.
    """
    # NOTE(review): resolved relative to the current working directory
    examples_dir = Path("text_recognizer") / "tests" / "support" / "paragraphs"
    # sorted so the example gallery order is stable; os.listdir order is arbitrary
    example_fnames = sorted(elem for elem in os.listdir(examples_dir) if elem.endswith(".png"))
    example_paths = [examples_dir / fname for fname in example_fnames]
    examples = [[str(path)] for path in example_paths]

    allow_flagging = "never"
    if flagging:
        allow_flagging = "manual"
        api_key = get_api_key()
        if gantry and api_key:  # if we're logging user feedback to Gantry and we have an API key
            allow_flagging = "manual"  # turn on Gradio flagging features
            # callback for logging input images, output text, and feedback to Gantry
            flagging_callback = GantryImageToTextLogger(application=app_name, api_key=api_key)
            # that sends images to S3
            flagging_dir = make_unique_bucket_name(prefix=app_name, seed=api_key)
        else:  # otherwise, log to a local CSV file
            if gantry and api_key is None:
                warnings.warn("No Gantry API key found, logging to local directory instead.", stacklevel=1)
            flagging_callback = gr.CSVLogger()
            flagging_dir = "flagged"
    else:
        flagging_callback, flagging_dir = None, None

    readme = _load_readme(with_logging=allow_flagging == "manual")

    # build a basic browser interface to a Python function
    frontend = gr.Interface(
        fn=fn,  # which Python function are we interacting with?
        outputs=gr.components.Textbox(),  # what output widgets does it need? the default text widget
        # what input widgets does it need? we configure an image widget
        inputs=gr.components.Image(type="pil", label="Handwritten Text"),
        title="📝 Text Recognizer",  # what should we display at the top of the page?
        thumbnail=FAVICON,  # what should we display when the link is shared, e.g. on social media?
        description=__doc__,  # what should we display just above the interface?
        article=readme,  # what long-form content should we display below the interface?
        examples=examples,  # which potential inputs should we provide?
        cache_examples=False,  # should we cache those inputs for faster inference? slows down start
        allow_flagging=allow_flagging,  # should we show users the option to "flag" outputs?
        flagging_options=["incorrect", "offensive", "other"],  # what options do users have for feedback?
        flagging_callback=flagging_callback,
        flagging_dir=flagging_dir,
    )
    return frontend
class PredictorBackend:
    """Interface to a backend that serves predictions.

    To communicate with a backend accessible via a URL, provide the url kwarg.
    Otherwise, runs a predictor locally.
    """

    def __init__(self, url=None):
        """Use the remote endpoint at `url` if given, else a local ParagraphTextRecognizer."""
        if url is not None:
            self.url = url
            self._predict = self._predict_from_endpoint
        else:
            model = ParagraphTextRecognizer()
            self._predict = model.predict

    def run(self, image):
        """Predict the text in `image`, log request metrics, and return the prediction."""
        pred, metrics = self._predict_with_metrics(image)
        self._log_inference(pred, metrics)
        return pred

    def _predict_with_metrics(self, image):
        """Return the prediction together with simple image and prediction statistics."""
        pred = self._predict(image)

        stats = ImageStat.Stat(image)
        metrics = {
            "image_mean_intensity": stats.mean,
            "image_median": stats.median,
            "image_extrema": stats.extrema,
            "image_area": image.size[0] * image.size[1],
            "pred_length": len(pred),
        }
        return pred, metrics

    def _predict_from_endpoint(self, image):
        """Send an image to an endpoint that accepts JSON and return the predicted text.

        The endpoint should expect a base64 representation of the image, encoded as a string,
        under the key "image". It should return the predicted text under the key "pred".
        Non-2xx HTTP responses raise requests.HTTPError instead of failing later on a missing "pred".

        Parameters
        ----------
        image
            A PIL image of handwritten text to be converted into a string.

        Returns
        -------
        pred
            A string containing the predictor's guess of the text in the image.
        """
        encoded_image = util.encode_b64_image(image)

        headers = {"Content-type": "application/json"}
        payload = json.dumps({"image": "data:image/png;base64," + encoded_image})

        response = requests.post(self.url, data=payload, headers=headers)
        # surface HTTP errors with status and URL, rather than as an opaque KeyError on "pred" below
        response.raise_for_status()
        pred = response.json()["pred"]

        return pred

    def _log_inference(self, pred, metrics):
        """Log each metric and then the prediction at INFO level."""
        for key, value in metrics.items():
            logging.info(f"METRIC {key} {value}")
        logging.info(f"PRED >begin\n{pred}\nPRED >end")
def _load_readme(with_logging=False):
with open(README) as f:
lines = f.readlines()
if not with_logging:
lines = lines[: lines.index("\n")]
readme = "".join(lines)
return readme
def _make_parser():
    """Build the command-line parser for the Gradio app.

    Returns
    -------
    argparse.ArgumentParser
        Parser for --model_url, --port, --flagging, --gantry and --application.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_url",
        default=None,
        type=str,
        # typo fix: "set via a POST request" -> "sent via a POST request"
        help="Identifies a URL to which to send image data. Data is base64-encoded, converted to a utf-8 string, and then sent via a POST request as JSON with the key 'image'. Default is None, which instead sends the data to a model running locally.",
    )
    parser.add_argument(
        "--port",
        default=DEFAULT_PORT,
        type=int,
        help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
    )
    parser.add_argument(
        "--flagging",
        action="store_true",
        help="Pass this flag to allow users to 'flag' model behavior and provide feedback.",
    )
    parser.add_argument(
        "--gantry",
        action="store_true",
        help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.",
    )
    parser.add_argument(
        "--application",
        default=DEFAULT_APPLICATION_NAME,
        type=str,
        help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. Default is {DEFAULT_APPLICATION_NAME}.",
    )
    return parser
if __name__ == "__main__":
    # command-line entry point: parse the flags defined in _make_parser and launch the app via main
    parser = _make_parser()
    args = parser.parse_args()
    main(args)
================================================
FILE: lab08/app_gradio/flagging.py
================================================
import os
from typing import List, Optional, Union
import gantry
import gradio as gr
from gradio.components import Component
from smart_open import open
from app_gradio import s3_util
from text_recognizer.util import read_b64_string
class GantryImageToTextLogger(gr.FlaggingCallback):
    """A FlaggingCallback that logs flagged image-to-text data to Gantry via S3."""

    def __init__(self, application: str, version: Union[int, str, None] = None, api_key: Optional[str] = None):
        """Prepare to log image-to-text examples flagged in Gradio to Gantry.

        Flagged input images are written to an Amazon S3 bucket whose name is the
        flagging_dir handed to the Gradio interface (see `setup`). The output text and
        user feedback are sent to Gantry alongside the image's s3:// URI.

        Background reading: boto3/S3 overview at https://realpython.com/python-boto3-aws-s3/,
        Gradio flagging at https://gradio.app/docs/#flagging, Gantry at https://docs.gantry.io.

        Parameters
        ----------
        application
            Name of the Gantry application that receives flagged data.
            Gantry validates and monitors data per application.
        version
            Schema version Gantry should validate against. If omitted, Gantry uses
            the latest version. An unknown version is created.
        api_key
            Optional Gantry API key, handy for local development and notebooks.
            It can instead be supplied through the GANTRY_API_KEY environment variable.
        """
        self.application, self.version = application, version
        gantry.init(api_key=api_key)

    def setup(self, components: List[Component], flagging_dir: str):
        """Attach to (or create) the S3 bucket named `flagging_dir` and locate the image/text components."""
        self._counter = 0
        self.bucket = s3_util.get_or_create_bucket(flagging_dir)
        s3_util.enable_bucket_versioning(self.bucket)
        s3_util.add_access_policy(self.bucket)
        self.image_component_idx, self.text_component_idx = self._find_image_and_text_components(components)

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
        """Upload the flagged input image to S3, send output text and feedback to Gantry, and return the running count."""
        image = flag_data[self.image_component_idx]
        text = flag_data[self.text_component_idx]

        feedback = {"flag": flag_option}
        if username is not None:
            feedback["user"] = username

        data_type, image_buffer = read_b64_string(image, return_data_type=True)
        image_url = self._to_s3(image_buffer.read(), filetype=data_type)
        self._to_gantry(image_url, text, feedback)

        self._counter += 1
        return self._counter

    def _to_gantry(self, input_image_url, output_text, feedback):
        """Record one flagged example in Gantry."""
        gantry.log_record(
            self.application,
            self.version,
            inputs={"image": input_image_url},
            outputs={"output_text": output_text},
            feedback=feedback,
        )

    def _to_s3(self, image_bytes, key=None, filetype=None):
        """Write `image_bytes` to the bucket (under a content-hash key by default) and return its s3:// URI."""
        key = s3_util.make_key(image_bytes, filetype=filetype) if key is None else key
        s3_uri = s3_util.get_uri_of(self.bucket, key)
        # `open` here is smart_open's, which accepts s3:// URIs
        with open(s3_uri, "wb") as s3_object:
            s3_object.write(image_bytes)
        return s3_uri

    def _find_image_and_text_components(self, components: List[Component]):
        """Return the indices of the (last) image component and (last) text component."""
        image_idx = text_idx = None
        for position, component in enumerate(components):
            if isinstance(component, (gr.inputs.Image, gr.components.Image)):
                image_idx = position
            elif isinstance(component, (gr.templates.Text, gr.components.Textbox)):
                text_idx = position

        if image_idx is None:
            raise RuntimeError(f"No image input found in gradio interface with components {components}")
        if text_idx is None:
            raise RuntimeError(f"No text output found in gradio interface with components {components}")
        return image_idx, text_idx
def get_api_key() -> Optional[str]:
    """Return the Gantry API key from the GANTRY_API_KEY environment variable, or None if unset."""
    return os.environ.get("GANTRY_API_KEY")
================================================
FILE: lab08/app_gradio/s3_util.py
================================================
import hashlib
import json
import boto3
import botocore
# template for an object's public HTTPS URL; filled with bucket, region and key by _format_url
S3_URL_FORMAT = "https://{bucket}.s3.{region}.amazonaws.com/{key}"
# template for an object's s3:// URI; filled with bucket and key by _format_uri
S3_URI_FORMAT = "s3://{bucket}/{key}"
# module-wide boto3 S3 resource, created at import time and shared by the helpers below
s3 = boto3.resource("s3")
def get_or_create_bucket(name):
    """Return the boto3 Bucket called `name`, trying to create it first.

    A 409 Conflict from the create call means the bucket already exists and is
    tolerated (we presume we have the right permissions). Any other ClientError is re-raised.
    """
    try:
        name, _ = _create_bucket(name)
    except botocore.exceptions.ClientError as err:
        # error handling from https://github.com/boto/boto3/issues/1195#issuecomment-495842252
        if err.response["ResponseMetadata"]["HTTPStatusCode"] != 409:
            raise
    return s3.Bucket(name)
def _create_bucket(name):
    """Create a bucket called `name` in the current session's default region.

    S3 rejects an explicit LocationConstraint of "us-east-1" (its default location).
    A region of None cannot be passed either. In both cases the configuration is
    omitted, which creates the bucket in us-east-1.

    Returns
    -------
    tuple
        (name, response from `create_bucket`)
    """
    region = boto3.session.Session().region_name  # sessions hold on to credentials and config
    if region is None or region == "us-east-1":
        bucket_response = s3.create_bucket(Bucket=name)
    else:
        bucket_config = {"LocationConstraint": region}
        bucket_response = s3.create_bucket(Bucket=name, CreateBucketConfiguration=bucket_config)
    return name, bucket_response
def make_key(fileobj, filetype=None):
    """Build an object key from the content hash of `fileobj`, with an optional ".<filetype>" suffix.

    Despite its name, `fileobj` is hashed directly by make_identifier; GantryImageToTextLogger passes raw bytes.
    """
    identifier = make_identifier(fileobj)
    if filetype is not None:
        return identifier + "." + filetype
    return identifier
def make_unique_bucket_name(prefix, seed):
    """Return `prefix` + "-" + the first 10 hex chars of sha256(seed), giving a stable per-seed bucket name."""
    digest = hashlib.sha256(seed.encode("utf-8")).hexdigest()
    return "-".join((prefix, digest[:10]))
def get_url_of(bucket, key=None):
    """Return the HTTPS URL of `bucket` (a name or a Bucket), or of object `key` inside it."""
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    region = _get_region(bucket_name)
    return _format_url(bucket_name, region, key or "")
def get_uri_of(bucket, key=None):
    """Return the s3:// URI of `bucket` (a name or a Bucket), or of object `key` inside it."""
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    return _format_uri(bucket_name, key or "")
def enable_bucket_versioning(bucket):
    """Turn on object versioning for `bucket` (a name or a Bucket) and return boto3's response.

    With versioning on, overwritten or deleted objects keep their earlier versions.
    """
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    return s3.BucketVersioning(bucket_name).enable()
def add_access_policy(bucket):
    """Attach the bucket policy from _get_policy, granting Gantry's AWS accounts read access.

    Like the other helpers in this module, `bucket` may be a bucket name (str) or a boto3 Bucket.
    Previously only Bucket objects were accepted.
    """
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    access_policy = json.dumps(_get_policy(bucket_name))
    s3.meta.client.put_bucket_policy(Bucket=bucket_name, Policy=access_policy)
def _get_policy(bucket_name):
"""Returns a bucket policy allowing Gantry app access as a JSON-compatible dictionary."""
return {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"AWS": [
"arn:aws:iam::848836713690:root",
"arn:aws:iam::339325199688:root",
"arn:aws:iam::665957668247:root",
]
},
"Action": ["s3:GetObject", "s3:GetObjectVersion"],
"Resource": f"arn:aws:s3:::{bucket_name}/*",
},
{
"Effect": "Allow",
"Principal": {
"AWS": [
"arn:aws:iam::848836713690:root",
"arn:aws:iam::339325199688:root",
"arn:aws:iam::665957668247:root",
]
},
"Action": "s3:ListBucketVersions",
"Resource": f"arn:aws:s3:::{bucket_name}",
},
],
}
def make_identifier(byte_data):
    """Return the hex SHA-1 digest of `byte_data`, used as a content-derived identifier.

    SHA-1 is acceptable here because the hash only names objects; it is not a security control.
    """
    return hashlib.sha1(byte_data).hexdigest()  # noqa: S3
def _get_region(bucket):
    """Determine the AWS region name of an S3 bucket (a name or a Bucket).

    GetBucketLocation reports a LocationConstraint of None for buckets in us-east-1
    and the legacy value "EU" for some eu-west-1 buckets. Both are mapped to real
    region names so URLs built from the result are valid.
    """
    bucket_name = bucket if isinstance(bucket, str) else bucket.name
    s3_client = boto3.client("s3")
    location = s3_client.get_bucket_location(Bucket=bucket_name)["LocationConstraint"]
    if not location:
        return "us-east-1"
    if location == "EU":
        return "eu-west-1"
    return location
def _format_url(bucket_name, region, key=None):
    """Fill S3_URL_FORMAT with the bucket, region and key (an empty key when None or empty)."""
    return S3_URL_FORMAT.format(bucket=bucket_name, region=region, key=key or "")
def _format_uri(bucket_name, key=None):
    """Fill S3_URI_FORMAT with the bucket and key (an empty key when None or empty)."""
    return S3_URI_FORMAT.format(bucket=bucket_name, key=key or "")
================================================
FILE: lab08/app_gradio/tests/test_app.py
================================================
import json
import os
import requests
from app_gradio import app
from text_recognizer import util
# hide all CUDA devices so the test exercises CPU inference
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# sample paragraph image; a relative path, so presumably tests run from the lab root -- confirm in CI config
TEST_IMAGE = "text_recognizer/tests/support/paragraphs/a01-077.png"
def test_local_run():
    """A quick test to make sure we can build the app and ping the API locally.

    Launches the frontend without blocking, checks that the UI page responds with 200,
    then POSTs a base64-encoded PNG to the ``api/predict`` endpoint and checks for 200.
    The server is always closed afterwards, even if an assertion fails, so it does not
    outlive the test.
    """
    backend = app.PredictorBackend()
    frontend = app.make_frontend(fn=backend.run)
    # run the UI without blocking
    frontend.launch(share=False, prevent_thread_lock=True)
    try:
        local_url = frontend.local_url
        get_response = requests.get(local_url)
        assert get_response.status_code == 200, get_response.content
        image_b64 = util.encode_b64_image(util.read_image_pil(TEST_IMAGE))
        local_api = f"{local_url}api/predict"
        headers = {"Content-Type": "application/json"}
        payload = json.dumps({"data": ["data:image/png;base64," + image_b64]})
        post_response = requests.post(local_api, data=payload, headers=headers)
        assert post_response.status_code == 200, post_response.content
    finally:
        # with prevent_thread_lock=True, launch returns while the server keeps running;
        # shut it down so neither a passing nor a failing run leaks a live server
        frontend.close()
================================================
FILE: lab08/notebooks/lab01_pytorch.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" `.\n",
"\n",
"A model that always predicts ` ` can achieve around 50% accuracy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EE-T7zgDgo7-"
},
"outputs": [],
"source": [
"padding_token = emnist_lines.emnist.inverse_mapping[\" \"]\n",
"torch.sum(line_ys == padding_token) / line_ys.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rGHWmOyVh5rV"
},
"source": [
"There are ways to adjust your classification metrics to\n",
"[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n",
"In general it's good to find a metric\n",
"that has baseline performance at 0 and perfect performance at 1,\n",
"so that numbers are clearly interpretable.\n",
"\n",
"But it's an important reminder to actually look\n",
"at your model's behavior from time to time.\n",
"Metrics are single numbers,\n",
"so they by necessity throw away a ton of information\n",
"about your model's behavior,\n",
"some of which is deeply relevant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6p--KWZ9YJWQ"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQnoOK8YLDv"
},
"source": [
"### 🌟 Research a `pl.Trainer` argument and try it out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7j652MtkYR8n"
},
"source": [
"The Lightning `Trainer` class is highly configurable\n",
"and has accumulated a number of features as Lightning has matured.\n",
"\n",
"Check out the documentation for this class\n",
"and pick an argument to try out with `training/run_experiment.py`.\n",
"Look for edge cases in its behavior,\n",
"especially when combined with other arguments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8UWNicq_jS7k"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"pl_version = pl.__version__\n",
"\n",
"print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n",
"print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n",
"\n",
"pl.Trainer??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "14AOfjqqYOoT"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "lab02b_cnn.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab08/notebooks/lab03_transformers.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" \", \"\")\n",
"\n",
"idx = random.randint(0, len(xs))\n",
"\n",
"print(show(ys[idx]))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4dT3UCNzTsoc"
},
"source": [
"The `ResnetTransformer` model can run on this data\n",
"if passed the `.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WXL-vIGRr86D"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MMxa-oWyT01E"
},
"source": [
"Our models are now big enough\n",
"that we want to make use of GPU acceleration\n",
"as much as we can,\n",
"even when working on single inputs,\n",
"so let's cast to the GPU if we have one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-YyUM8LgvW0w"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"rnt.to(device); xs = xs.to(device); ys = ys.to(device);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y-E3UdD4zUJi"
},
"source": [
"First, let's just pass it through the ResNet encoder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-LUUtlvaxrvg"
},
"outputs": [],
"source": [
"resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n",
" # resnet is designed for RGB images, so we replicate the input across channels 3 times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eimgJ5dnywjg"
},
"outputs": [],
"source": [
"resnet_idx = random.randint(0, len(resnet_embedding)) # re-execute to view a different channel\n",
"plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n",
"plt.axis(\"off\"); plt.colorbar(fraction=0.05);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These embeddings, though generated by random, untrained weights,\n",
"are not entirely useless.\n",
"\n",
"Before neural networks could be effectively\n",
"trained end to end,\n",
"they were often used with frozen random weights\n",
"eveywhere except the final layer\n",
"(see e.g.\n",
"[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n",
"[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n",
"these methods were still competitive, and\n",
"[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n",
"provide a\n",
"[theoretical basis](https://arxiv.org/abs/2011.14522)\n",
"for understanding their performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ye6pW0ETzw2A"
},
"source": [
"The final result, though, is repetitive gibberish --\n",
"at the bare minimum, we need to train the unembedding/readout layer\n",
"in order to get reasonable text."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our architecture includes randomization with dropout,\n",
"so repeated runs of the cell below will generate different outcomes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu3Pa7gLsFMo"
},
"outputs": [],
"source": [
"preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gvCXUbskv6XM"
},
"outputs": [],
"source": [
"print(show(preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Without teacher forcing, runtime is also variable from iteration to iteration --\n",
"the model stops when it generates an \"end sequence\" or padding token,\n",
"which is not deterministic thanks to the dropout layers.\n",
"For similar reasons, runtime is variable across inputs.\n",
"\n",
"The variable runtime of autoregressive generation\n",
"is also not great for scaling.\n",
"In a distributed setting, as required for large scale,\n",
"forward passes need to be synced across devices,\n",
"and if one device is generating a batch of much longer sequences,\n",
"it will cause all the others to idle while they wait on it to finish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t76MSVRXV0V7"
},
"source": [
"Let's turn our model into a `TransformerLitModel`\n",
"so we can run with teacher forcing.\n",
"\n",
"> You may be wondering:\n",
" why isn't teacher forcing part of the PyTorch module?\n",
" In general, the `LightningModule`\n",
" should encapsulate things that are needed in training, validation, and testing\n",
" but not during inference.\n",
" The teacher forcing trick fits this paradigm,\n",
" even though it's so critical to what makes Transformers powerful. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8qrHRKHowdDi"
},
"outputs": [],
"source": [
"import text_recognizer.lit_models\n",
"\n",
"lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MlNaFqR50Oid"
},
"source": [
"Now we can use `.teacher_forward` if we also provide the target `ys`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lpZdqXS5wn0F"
},
"outputs": [],
"source": [
"forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Zx9SmsN0QLT"
},
"source": [
"This may not run faster than the `rnt.forward`,\n",
"since generations are always the maximum possible length,\n",
"but runtimes and output lengths are deterministic and constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tu-XNYpi0Qvi"
},
"source": [
"Forcing doesn't necessarily make our predictions better.\n",
"They remain highly repetitive gibberish."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JcEgify9w0sv"
},
"outputs": [],
"source": [
"forcing_preds = torch.argmax(forcing_outs, dim=0)\n",
"\n",
"print(show(forcing_preds.cpu()))\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xn6GGNzc9a3o"
},
"source": [
"## Training the `ResNetTransformer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uvZYsuSyWUXe"
},
"source": [
"We're finally ready to train this model on full paragraphs of handwritten text!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cJwC7b720Sd"
},
"source": [
"This is a more serious model --\n",
"it's the one we use in the\n",
"[deployed TextRecognizer application](http://fsdl.me/app).\n",
"It's much larger than the models we've seen this far,\n",
"so it can easily outstrip available compute resources,\n",
"in particular GPU memory.\n",
"\n",
"To help, we use\n",
"[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n",
"which shrinks the size of most of our floats by half,\n",
"which reduces memory consumption and can speed up computation.\n",
"\n",
"If your GPU has less than 8GB of available RAM,\n",
"you'll see a \"CUDA out of memory\" `RuntimeError`,\n",
"which is something of a\n",
"[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n",
"In this case, you can resolve it by reducing the `--batch_size`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w1mXlhfy04Nm"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"gpus = int(torch.cuda.is_available())\n",
"\n",
"if gpus:\n",
" !nvidia-smi\n",
"else:\n",
" print(\"watch out! working with this model on a typical CPU is not feasible\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "os1vW1rPZ1dy"
},
"source": [
"Even with an okay GPU, like a\n",
"[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n",
"a single epoch of training can take over 10 minutes to run.\n",
"We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n",
"but you can remove those flags to see what full training looks like."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vnF6dWFn4JlZ"
},
"source": [
"It can take a long time (overnight)\n",
"to train this model to decent performance on a single GPU,\n",
"so we'll focus on other pieces for the exercises.\n",
"\n",
"> At the time of writing in mid-2022, the cheapest readily available option\n",
"for training this model to decent performance on this dataset with this codebase\n",
"comes out around $10, using\n",
"[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n",
"See, for example,\n",
"[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n",
"and associated experiment.\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HufjdUZN0t4l",
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"# above %%magic times the cell, useful as a poor man's profiler\n",
"\n",
"%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n",
" --gpus={gpus} --batch_size 16 --precision 16 \\\n",
" --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L6fQ93ju3Iku"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udb1Ekjx3L63"
},
"source": [
"### 🌟 Try out gradient accumulation and other \"training tricks\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kpqViB4p3Wfb"
},
"source": [
"Larger batches are helpful not only for increasing parallelization\n",
"and amortizing fixed costs\n",
"but also for getting more reliable gradients.\n",
"Larger batches give gradients with less noise\n",
"and to a point, less gradient noise means faster convergence.\n",
"\n",
"But larger batches result in larger tensors,\n",
"which take up more GPU memory,\n",
"a resource that is tightly constrained\n",
"and device-dependent.\n",
"\n",
"Does that mean we are limited in the quality of our gradients\n",
"due to our machine size?\n",
"\n",
"Not entirely:\n",
"look up the `--accumulate_grad_batches`\n",
"argument to the `pl.Trainer`.\n",
"You should be able to understand why\n",
"it makes it possible to compute the same gradients\n",
"you would find for a batch of size `k * N`\n",
"on a machine that can only run batches up to size `N`.\n",
"\n",
"Accumulating gradients across batches is among the\n",
"[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n",
"Try some of them out!\n",
"Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2vtkmX830y3"
},
"source": [
"### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n",
"\n",
"While training this model to actually fit the whole dataset is infeasible\n",
"as a short exercise on commodity hardware,\n",
"it's practical to train this model to memorize a batch of 16 examples.\n",
"\n",
"Passing `--overfit_batches 1` flag limits the number of training batches to 1\n",
"and turns off\n",
"[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n",
"so that in each epoch, the model just sees the same single batch of data over and over again.\n",
"\n",
"At first, try training the model to a loss of `2.5` --\n",
"it should be doable in 100 epochs or less,\n",
"which is just a few minutes on a commodity GPU.\n",
"\n",
"Once you've got that working,\n",
"crank up the number of epochs by a factor of 10\n",
"and confirm that the loss continues to go down.\n",
"\n",
"Some tips:\n",
"\n",
"- Use `--limit_test_batches 0` to turn off testing.\n",
"We don't need it because we don't care about generalization\n",
"and it's relatively slow because it runs the model autoregressively.\n",
"\n",
"- Use `--help` and look through the model class args\n",
"to find the arguments used to reduce model size.\n",
"\n",
"- By default, there's lots of regularization to prevent overfitting.\n",
"Look through the args for the model class and data class\n",
"for regularization knobs to turn off or down."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab03_transformers.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab08/notebooks/lab04_experiments.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
" ", *characters, *iam_characters]
if __name__ == "__main__":
    # Script entry point: run load_and_print_info on the EMNIST data module class.
    load_and_print_info(EMNIST)
================================================
FILE: lab08/text_recognizer/data/emnist_essentials.json
================================================
{"characters": ["", " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
================================================
FILE: lab08/text_recognizer/data/emnist_lines.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, Sequence
import h5py
import numpy as np
import torch
from text_recognizer.data import EMNIST
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.emnist_lines as metadata
from text_recognizer.stems.image import ImageStem
# File locations come from the EMNIST Lines metadata module.
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME

# Line-generation defaults; overridable through EMNISTLines args / CLI flags.
DEFAULT_MAX_LENGTH = 32
DEFAULT_MIN_OVERLAP = 0
DEFAULT_MAX_OVERLAP = 0.33

# Number of synthetic lines generated for each split.
NUM_TRAIN = 10_000
NUM_VAL = 2_000
NUM_TEST = 2_000
class EMNISTLines(BaseDataModule):
    """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.

    Each line is built by generating a sentence, picking an EMNIST sample image for each
    of its characters, and pasting those images side by side with a random horizontal
    overlap. All three splits are cached in one HDF5 file whose name encodes the
    generation parameters.
    """

    def __init__(
        self,
        args: argparse.Namespace = None,
    ):
        super().__init__(args)
        # NOTE(review): self.args is set up by BaseDataModule and read dict-style via .get
        self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH)
        self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP)
        self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP)
        self.num_train = self.args.get("num_train", NUM_TRAIN)
        self.num_val = self.args.get("num_val", NUM_VAL)
        self.num_test = self.args.get("num_test", NUM_TEST)
        self.with_start_end_tokens = self.args.get("with_start_end_tokens", False)
        self.mapping = metadata.MAPPING
        self.output_dims = (self.max_length, 1)
        # line images are wide enough for max_length characters placed with no overlap
        max_width = metadata.CHAR_WIDTH * self.max_length
        self.input_dims = (*metadata.DIMS[:2], max_width)
        self.emnist = EMNIST()
        self.transform = ImageStem()

    @staticmethod
    def add_to_argparse(parser):
        """Register the base data-module flags plus line-generation flags on ``parser``."""
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument(
            "--max_length",
            type=int,
            default=DEFAULT_MAX_LENGTH,
            help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}",
        )
        parser.add_argument(
            "--min_overlap",
            type=float,
            default=DEFAULT_MIN_OVERLAP,
            help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}",
        )
        parser.add_argument(
            "--max_overlap",
            type=float,
            default=DEFAULT_MAX_OVERLAP,
            help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}",
        )
        parser.add_argument("--with_start_end_tokens", action="store_true", default=False)
        return parser

    @property
    def data_filename(self):
        """Path of the HDF5 cache; the file name encodes every generation parameter."""
        return (
            PROCESSED_DATA_DIRNAME
            / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5"
        )

    def prepare_data(self, *args, **kwargs) -> None:
        """Generate and cache all splits, unless a cache file already exists.

        Splits are appended to the cache file one at a time, while the existence check
        treats any file as complete. If generation fails or is interrupted part way, the
        partially written file is removed before re-raising so later runs regenerate it.
        """
        if self.data_filename.exists():
            return
        np.random.seed(42)
        try:
            self._generate_data("train")
            self._generate_data("val")
            self._generate_data("test")
        except BaseException:
            # BaseException so that Ctrl-C mid-generation also cleans up the partial file
            if self.data_filename.exists():
                self.data_filename.unlink()
            raise

    def setup(self, stage: str = None) -> None:
        """Load the cached splits needed for ``stage`` ("fit", "test", or None for all)."""
        print("EMNISTLinesDataset loading data from HDF5...")
        if stage == "fit" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_train = f["x_train"][:]
                y_train = f["y_train"][:].astype(int)
                x_val = f["x_val"][:]
                y_val = f["y_val"][:].astype(int)
            self.data_train = BaseDataset(x_train, y_train, transform=self.transform)
            self.data_val = BaseDataset(x_val, y_val, transform=self.transform)
        if stage == "test" or stage is None:
            with h5py.File(self.data_filename, "r") as f:
                x_test = f["x_test"][:]
                y_test = f["y_test"][:].astype(int)
            self.data_test = BaseDataset(x_test, y_test, transform=self.transform)

    def __repr__(self) -> str:
        """Print info about the dataset."""
        basic = (
            "EMNIST Lines Dataset\n"
            f"Min overlap: {self.min_overlap}\n"
            f"Max overlap: {self.max_overlap}\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Dims: {self.input_dims}\n"
            f"Output dims: {self.output_dims}\n"
        )
        if self.data_train is None and self.data_val is None and self.data_test is None:
            return basic
        x, y = next(iter(self.train_dataloader()))
        data = (
            f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
            f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n"
            f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n"
        )
        return basic + data

    def _generate_data(self, split: str) -> None:
        """Synthesize ``split`` and append its images/labels to the HDF5 cache file."""
        print(f"EMNISTLinesDataset generating data for {split}...")
        from text_recognizer.data.sentence_generator import SentenceGenerator

        sentence_generator = SentenceGenerator(self.max_length - 2)  # Subtract two because we will add start/end tokens
        emnist = self.emnist
        emnist.prepare_data()
        emnist.setup()
        # train and val lines both draw character samples from EMNIST's trainval pool;
        # test lines use EMNIST's held-out test characters
        if split == "train":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_train
        elif split == "val":
            samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping)
            num = self.num_val
        else:
            samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping)
            num = self.num_test
        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        # "a" mode: each split is added to the same file as its own x_/y_ datasets
        with h5py.File(self.data_filename, "a") as f:
            x, y = create_dataset_of_images(
                num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims
            )
            y = convert_strings_to_labels(
                y,
                emnist.inverse_mapping,
                length=self.output_dims[0],
                with_start_end_tokens=self.with_start_end_tokens,
            )
            f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
            f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def get_samples_by_char(samples, labels, mapping):
    """Group ``samples`` by the character their label maps to.

    Returns a ``defaultdict(list)`` keyed by character, so looking up a character with
    no samples yields an empty list rather than raising ``KeyError``. Samples and labels
    are paired positionally (extra items in the longer sequence are ignored).
    """
    grouped = defaultdict(list)
    for image, label_index in zip(samples, labels):
        character = mapping[label_index]
        grouped[character].append(image)
    return grouped
def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)):
    """Return one image per character of ``string``, reshaped to ``char_shape``.

    Each distinct character gets a single randomly chosen sample (reused for every
    occurrence of that character); characters without samples get an all-zero uint8
    image. Random draws happen once per distinct character, in first-appearance order.
    """
    blank = torch.zeros(char_shape, dtype=torch.uint8)
    chosen = {}
    for ch in dict.fromkeys(string):  # unique characters, first-appearance order
        pool = samples_by_char[ch]
        pick = pool[np.random.choice(len(pool))] if pool else blank
        chosen[ch] = pick.reshape(*char_shape)
    return [chosen[ch] for ch in string]
def construct_image_from_string(
    string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int
) -> torch.Tensor:
    """Render ``string`` as one line image by pasting character samples left to right.

    A single overlap fraction is drawn uniformly from ``[min_overlap, max_overlap)`` and
    used between every pair of neighbouring characters. Where neighbouring characters
    overlap, their pixel values are summed and the result is clipped at 255.

    Returns
    -------
    torch.Tensor
        Float tensor of shape ``(H, width)``, where ``H`` is the character image height.
    """
    overlap = np.random.uniform(min_overlap, max_overlap)
    sampled_images = select_letter_samples_for_string(string, samples_by_char)
    H, W = sampled_images[0].shape
    next_overlap_width = W - int(overlap * W)
    # Accumulate in a wide integer dtype: summing overlapping strokes in uint8 wraps
    # around modulo 256, and the clip to 255 below cannot undo that wrap-around.
    concatenated_image = torch.zeros((H, width), dtype=torch.int32)
    x = 0
    for image in sampled_images:
        concatenated_image[:, x : (x + W)] += torch.as_tensor(image).to(torch.int32)
        x += next_overlap_width
    return torch.minimum(torch.Tensor([255]), concatenated_image)
def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims):
    """Generate ``N`` sentences and render each one as a line image.

    Returns
    -------
    (torch.Tensor, list of str)
        A float tensor of shape ``(N, dims[1], dims[2])`` holding the rendered lines,
        and the list of generated sentences in matching order.
    """
    images = torch.zeros((N, dims[1], dims[2]))
    sentences = []
    for row in range(N):
        sentence = sentence_generator.generate()
        sentences.append(sentence)
        images[row] = construct_image_from_string(sentence, samples_by_char, min_overlap, max_overlap, dims[-1])
    return images, sentences
def convert_strings_to_labels(
strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool
) -> np.ndarray:
"""
Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with token.
"""
labels = np.ones((len(strings), length), dtype=np.uint8) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
if with_start_end_tokens:
tokens = [" token.
"""
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping[" "]
for i, string in enumerate(strings):
tokens = list(string)
tokens = [" "]
self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
self.val_cer = CharacterErrorRate(self.ignore_tokens)
self.test_cer = CharacterErrorRate(self.ignore_tokens)
================================================
FILE: lab08/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence
import torch
import torchmetrics
class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored.

    Predictions and targets are integer token tensors. Any token id in
    ``ignore_tokens`` is removed from every sequence before the parent
    ``CharErrorRate`` computes edit distances.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args, **kwargs):
        # keyword options are forwarded to torchmetrics as well as positional ones
        super().__init__(*args, **kwargs)
        self.ignore_tokens = set(ignore_tokens)

    def _strip_ignored(self, sequences: torch.Tensor):
        """Convert a batch tensor to nested Python lists with ignored token ids dropped."""
        return [[t for t in seq if t not in self.ignore_tokens] for seq in sequences.tolist()]

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        super().update(self._strip_ignored(preds), self._strip_ignored(targets))
def test_character_error_rate():
    """Check CER when tokens 0 and 1 are ignored and all targets have equal length."""
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    expected = sum([0, 0.75, 0.5]) / 3
    # compute() returns a tensor; compare with a tolerance rather than exact float equality
    result = metric.compute().item()
    assert abs(result - expected) < 1e-6, result


if __name__ == "__main__":
    test_character_error_rate()
================================================
FILE: lab08/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence
import torch
from .base import BaseImageToTextLitModel
from .util import replace_after
class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.
    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """

    def __init__(self, model, args=None):
        super().__init__(model, args)
        # padding positions in the target do not contribute to the loss
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)

    def forward(self, x):
        return self.model(x)

    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.
        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.
        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def training_step(self, batch, batch_idx):
        """Teacher-forced loss: feed y without its last token, predict y shifted by one."""
        x, y = batch
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])
        self.log("train/loss", loss)
        outputs = {"loss": loss}
        if self.is_logged_batch():
            preds = self.get_preds(logits)
            pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
            outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs})
        return outputs

    def validation_step(self, batch, batch_idx):
        """Log teacher-forced loss and autoregressive CER on the validation split."""
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])
        self.log("validation/loss", loss, prog_bar=True, sync_dist=True)
        outputs = {"loss": loss}
        # compute predictions as in production, for comparison
        preds = self(x)
        self.val_cer(preds, y)
        self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True)
        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)
        return outputs

    def test_step(self, batch, batch_idx):
        """Log teacher-forced loss and autoregressive CER on the test split."""
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])
        self.log("test/loss", loss, prog_bar=True, sync_dist=True)
        outputs = {"loss": loss}
        # compute predictions as in production, for comparison
        preds = self(x)
        # use the dedicated test metric so test results never mix with validation state
        self.test_cer(preds, y)
        self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True)
        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)
        return outputs

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping."""
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        else:
            return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping."""
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.
        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.
        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        else:
            return raw  # (B, Sy)
================================================
FILE: lab08/text_recognizer/lit_models/util.py
================================================
from typing import Union
import torch
def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Return indices of first appearance of element in x, collapsing along dim.
    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9
    Parameters
    ----------
    x
        One or two-dimensional Tensor to search for element.
    element
        Item to search for inside x.
    dim
        Dimension of Tensor to collapse over.
    Returns
    -------
    torch.Tensor
        Indices where element occurs in x. If element is not found,
        return length of x along dim. One dimension smaller than x.
    Raises
    ------
    ValueError
        if x is not a 1 or 2 dimensional Tensor
    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    if not 1 <= x.dim() <= 2:
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}")
    is_element = x == element
    # a running count of matches equals 1 from the first match onward until the second;
    # and-ing with the match mask keeps exactly the first match position
    only_first = is_element & (is_element.cumsum(dim) == 1)
    found, position = only_first.max(dim)
    # rows with no match report the length along dim instead of a bogus index
    return torch.where(found, position, x.shape[dim])
def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Replace all values in each row of 2d Tensor x after the first appearance of element with replace.

    The first occurrence of element is itself kept; only later positions in its row are
    overwritten. Rows that never contain element come back unchanged.

    Parameters
    ----------
    x
        Two-dimensional Tensor (shape denoted (B, S)) to replace values in.
    element
        Item to search for inside x.
    replace
        Item that replaces entries that appear after element.

    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    cutoff = first_appearance(x, element, dim=1)  # (B,): first match index, or S if absent
    positions = torch.arange(0, x.shape[-1]).type_as(x)  # (S,)
    keep = positions.unsqueeze(0) <= cutoff.unsqueeze(1)  # (B, S): True up to and including the match
    return torch.where(keep, x, replace)  # (B, S)
================================================
FILE: lab08/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path
import text_recognizer.metadata.shared as shared
RAW_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("raw", "emnist")
METADATA_FILENAME = RAW_DATA_DIRNAME.joinpath("metadata.toml")
DL_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("downloaded", "emnist")
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "emnist")
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME.joinpath("byclass.h5")
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve().joinpath("data", "emnist_essentials.json")

NUM_SPECIAL_TOKENS = 4

INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # leading channel axis, the one ToTensor() adds
OUTPUT_DIMS = (1,)

# Index layout: 0-1 leading entries, 2-11 digits, 12-37 A-Z, 38-63 a-z, 64 space, 65-80 punctuation.
# NOTE(review): only two leading entries are listed although NUM_SPECIAL_TOKENS is 4, and " " occurs
# at both index 1 and index 64 -- confirm against the mapping stored in emnist_essentials.json.
MAPPING = [
    "",
    " ",
    *[str(digit) for digit in range(10)],
    *[chr(code) for code in range(ord("A"), ord("Z") + 1)],
    *[chr(code) for code in range(ord("a"), ord("z") + 1)],
    " ",
    *"!\"#&'()*+,-./:;?",
]
================================================
FILE: lab08/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "emnist_lines")
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve().joinpath("data", "emnist_lines_essentials.json")

# Each character keeps the EMNIST resolution.
CHAR_HEIGHT = emnist.DIMS[1]
CHAR_WIDTH = emnist.DIMS[2]
# Width is left as None: it varies with the maximum sequence length.
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)

MAPPING = emnist.MAPPING  # the same list object as the EMNIST metadata, not a copy
================================================
FILE: lab08/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared
RAW_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("raw", "iam")
METADATA_FILENAME = RAW_DATA_DIRNAME.joinpath("metadata.toml")
DL_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("downloaded", "iam")
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME.joinpath("iamdb")

# When the images were downsampled, region coordinates must be scaled by the same factor.
DOWNSAMPLE_FACTOR = 2
# Margin, in pixels, added around the exact line-region coordinates.
LINE_REGION_PADDING = 8
================================================
FILE: lab08/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_lines")

IMAGE_SCALE_FACTOR = 2

# Rough per-character width estimate derived from the EMNIST character size.
CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR
IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR  # rounded up from the empirical maximum IAMLines width

DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
OUTPUT_DIMS = (89, 1)

MAPPING = emnist.MAPPING
================================================
FILE: lab08/text_recognizer/metadata/iam_paragraphs.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_paragraphs")

NEW_LINE_TOKEN = "\n"
MAPPING = emnist.MAPPING + [NEW_LINE_TOKEN]  # new list: EMNIST symbols plus a newline token

# Must equal the IAMLines IMAGE_SCALE_FACTOR so line crops are compatible with synthetic paragraphs.
IMAGE_SCALE_FACTOR = 2

IMAGE_SHAPE = (576, 640)
IMAGE_HEIGHT, IMAGE_WIDTH = IMAGE_SHAPE

MAX_LABEL_LENGTH = 682

DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1)
================================================
FILE: lab08/text_recognizer/metadata/iam_synthetic_paragraphs.py
================================================
import text_recognizer.metadata.iam_paragraphs as iam_paragraphs
import text_recognizer.metadata.shared as shared
NEW_LINE_TOKEN = iam_paragraphs.NEW_LINE_TOKEN

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME.joinpath("processed", "iam_synthetic_paragraphs")

# Parameters of a typical training run; the dataset length is sized from them.
EXPECTED_BATCH_SIZE = 64
EXPECTED_GPUS = 8
EXPECTED_STEPS = 40
DATASET_LEN = EXPECTED_STEPS * EXPECTED_GPUS * EXPECTED_BATCH_SIZE  # 20480
================================================
FILE: lab08/text_recognizer/metadata/mnist.py
================================================
"""Metadata for the MNIST dataset."""
import text_recognizer.metadata.shared as shared
DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME

DIMS = (1, 28, 28)  # (channels, height, width)
OUTPUT_DIMS = (1,)
MAPPING = [*range(10)]  # class index i is the digit i

TRAIN_SIZE = 55_000
VAL_SIZE = 5_000
================================================
FILE: lab08/text_recognizer/metadata/shared.py
================================================
from pathlib import Path
# "data" directory at parents[3] of this file, i.e. next to the lab directory.
DATA_DIRNAME = Path(__file__).resolve().parents[3].joinpath("data")
DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME.joinpath("downloaded")
================================================
FILE: lab08/text_recognizer/models/__init__.py
================================================
"""Models for character and text recognition in images."""
from .mlp import MLP
from .cnn import CNN
from .line_cnn_simple import LineCNNSimple
from .resnet_transformer import ResnetTransformer
from .line_cnn_transformer import LineCNNTransformer
================================================
FILE: lab08/text_recognizer/models/cnn.py
================================================
"""Basic convolutional model building blocks."""
import argparse
from typing import Any, Dict
import torch
from torch import nn
import torch.nn.functional as F
CONV_DIM = 64
FC_DIM = 128
FC_DROPOUT = 0.25


class ConvBlock(nn.Module):
    """3x3 convolution (stride 1, padding 1, so H and W are preserved) followed by a ReLU."""

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution, then the ReLU.

        Parameters
        ----------
        x
            (B, C, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, output_channels, H, W) tensor
        """
        return self.relu(self.conv(x))
class CNN(nn.Module):
    """Small two-convolution CNN that classifies a single square character image."""

    def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = {} if args is None else vars(args)
        self.data_config = data_config

        input_channels, input_height, input_width = self.data_config["input_dims"]
        assert (
            input_height == input_width
        ), f"input height and width should be equal, but was {input_height}, {input_width}"
        self.input_height, self.input_width = input_height, input_width
        num_classes = len(self.data_config["mapping"])

        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)

        self.conv1 = ConvBlock(input_channels, conv_dim)
        self.conv2 = ConvBlock(conv_dim, conv_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.max_pool = nn.MaxPool2d(2)

        # The padded 3x3 convs keep H and W; only the 2x2 max-pool halves them (flooring odd sizes).
        pooled_height = input_height // 2
        pooled_width = input_width // 2
        self.fc_input_dim = int(pooled_height * pooled_width * conv_dim)
        self.fc1 = nn.Linear(self.fc_input_dim, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Parameters
        ----------
        x
            (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config.

        Returns
        -------
        torch.Tensor
            (B, Cl) tensor of fc2 outputs, one score per class.
        """
        _B, _Ch, H, W = x.shape
        assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}"
        features = self.conv2(self.conv1(x))  # (B, CONV_DIM, H, W)
        features = self.dropout(self.max_pool(features))  # (B, CONV_DIM, H // 2, W // 2)
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))  # (B, FC_DIM)
        return self.fc2(hidden)  # (B, Cl)

    @staticmethod
    def add_to_argparse(parser):
        for flag, kind, default in (
            ("--conv_dim", int, CONV_DIM),
            ("--fc_dim", int, FC_DIM),
            ("--fc_dropout", float, FC_DROPOUT),
        ):
            parser.add_argument(flag, type=kind, default=default)
        return parser
================================================
FILE: lab08/text_recognizer/models/line_cnn.py
================================================
"""Basic building blocks for convolutional models over lines of text."""
import argparse
import math
from typing import Any, Dict, Tuple, Union
import torch
from torch import nn
import torch.nn.functional as F
# Common type hints
Param2D = Union[int, Tuple[int, int]]

CONV_DIM = 32
FC_DIM = 512
FC_DROPOUT = 0.2
WINDOW_WIDTH = 16
WINDOW_STRIDE = 8


class ConvBlock(nn.Module):
    """
    Conv2d followed by a ReLU.

    With the default arguments (3x3 kernel, stride 1, padding 1) the spatial size is unchanged;
    other kernel_size / stride / padding values change H and W according to the usual Conv2d rules.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: Param2D = 3,
        stride: Param2D = 1,
        padding: Param2D = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the ConvBlock to x.

        Parameters
        ----------
        x
            (B, input_channels, H, W) tensor

        Returns
        -------
        torch.Tensor
            (B, output_channels, H', W') tensor; H' == H and W' == W with the default
            kernel_size/stride/padding, otherwise as determined by the Conv2d arithmetic.
        """
        c = self.conv(x)
        r = self.relu(c)
        return r
class LineCNN(nn.Module):
    """
    Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits
    """

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.args = vars(args) if args is not None else {}
        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        C, H, _W = data_config["input_dims"]
        conv_dim = self.args.get("conv_dim", CONV_DIM)
        fc_dim = self.args.get("fc_dim", FC_DIM)
        fc_dropout = self.args.get("fc_dropout", FC_DROPOUT)
        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        # Input is (C, H, W); the first conv takes the channel count declared in data_config["input_dims"].
        self.convs = nn.Sequential(
            ConvBlock(C, conv_dim),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim, stride=2),
            ConvBlock(conv_dim, conv_dim),
            ConvBlock(conv_dim, conv_dim * 2, stride=2),
            ConvBlock(conv_dim * 2, conv_dim * 2),
            ConvBlock(conv_dim * 2, conv_dim * 4, stride=2),
            ConvBlock(conv_dim * 4, conv_dim * 4),
            # The three stride-2 convs above shrink H and W by roughly 8x. This conv covers (about) the
            # remaining height and steps along the width with window/stride WW // 8 and WS // 8;
            # forward() squeezes dim 2, so it is expected to leave height 1.
            ConvBlock(
                conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0
            ),
        )
        self.fc1 = nn.Linear(fc_dim, fc_dim)
        self.dropout = nn.Dropout(fc_dropout)
        self.fc2 = nn.Linear(fc_dim, self.num_classes)

        self._init_weights()

    def _init_weights(self):
        """
        Initialize weights in a better way than default.
        See https://github.com/pytorch/pytorch/issues/18182

        Conv/Linear weights get Kaiming-normal (fan_out, ReLU) init; biases are drawn
        uniformly from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        """
        for m in self.modules():
            if type(m) in {
                nn.Conv2d,
                nn.Conv3d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
                nn.Linear,
            }:
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    # uniform_(tensor, a, b): normal_(tensor, -bound, bound) would instead mean
                    # mean=-bound, std=bound, biasing every unit negative.
                    nn.init.uniform_(m.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the LineCNN to a black-and-white input image.

        Parameters
        ----------
        x
            (B, C, H, W) input image, with C matching data_config["input_dims"][0] (1 for the line datasets)

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W and self.WW / self.WS (window width and stride)
            C is self.num_classes
        """
        _B, _C, _H, _W = x.shape
        x = self.convs(x)  # (B, FC_DIM, 1, Sx)
        x = x.squeeze(2).permute(0, 2, 1)  # (B, S, FC_DIM)
        x = F.relu(self.fc1(x))  # -> (B, S, FC_DIM)
        x = self.dropout(x)
        x = self.fc2(x)  # (B, S, C)
        x = x.permute(0, 2, 1)  # -> (B, C, S)
        if self.limit_output_length:
            x = x[:, :, : self.output_length]
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--conv_dim", type=int, default=CONV_DIM)
        parser.add_argument("--fc_dim", type=int, default=FC_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab08/text_recognizer/models/line_cnn_simple.py
================================================
"""Simplest version of LineCNN that works on cleanly-separated characters."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .cnn import CNN
IMAGE_SIZE = 28
WINDOW_WIDTH = IMAGE_SIZE
WINDOW_STRIDE = IMAGE_SIZE


class LineCNNSimple(nn.Module):
    """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.data_config = data_config

        self.WW = self.args.get("window_width", WINDOW_WIDTH)
        self.WS = self.args.get("window_stride", WINDOW_STRIDE)
        self.limit_output_length = self.args.get("limit_output_length", False)

        self.num_classes = len(data_config["mapping"])
        self.output_length = data_config["output_dims"][0]

        # The per-window CNN is built for square (WW x WW) crops.
        cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW)
        cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}}
        self.cnn = CNN(data_config=cnn_data_config, args=args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the LineCNN to an input image and return logits.

        Parameters
        ----------
        x
            (B, C, H, W) input image with H equal to the window width (IMAGE_SIZE by default)

        Returns
        -------
        torch.Tensor
            (B, C, S) logits, where S is the length of the sequence and C is the number of classes
            S can be computed from W and the window width/stride
            C is self.num_classes
        """
        B, _C, H, W = x.shape
        # self.cnn was built for (WW x WW) windows and asserts on its input size, so the line height
        # must equal the window width (IMAGE_SIZE with the default --window_width).
        assert H == self.WW, f"expected input height {self.WW}, got {H}"

        # Number of windows that fit; clamped at 0 so very narrow inputs give an empty
        # sequence instead of a negative tensor size.
        S = max(0, math.floor((W - self.WW) / self.WS + 1))

        # NOTE: type_as properly sets device
        activations = torch.zeros((B, self.num_classes, S)).type_as(x)
        for s in range(S):
            start_w = self.WS * s
            end_w = start_w + self.WW
            window = x[:, :, :, start_w:end_w]  # -> (B, C, H, self.WW)
            activations[:, :, s] = self.cnn(window)

        if self.limit_output_length:
            # S might not match ground truth, so let's only take enough activations as are expected
            activations = activations[:, :, : self.output_length]
        return activations

    @staticmethod
    def add_to_argparse(parser):
        CNN.add_to_argparse(parser)
        parser.add_argument(
            "--window_width",
            type=int,
            default=WINDOW_WIDTH,
            help="Width of the window that will slide over the input image.",
        )
        parser.add_argument(
            "--window_stride",
            type=int,
            default=WINDOW_STRIDE,
            help="Stride of the window that will slide over the input image.",
        )
        parser.add_argument("--limit_output_length", action="store_true", default=False)
        return parser
================================================
FILE: lab08/text_recognizer/models/line_cnn_transformer.py
================================================
"""Model that combines a LineCNN with a Transformer model for text prediction."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
from .line_cnn import LineCNN
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding
TF_DIM = 256
TF_FC_DIM = 256
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4


class LineCNNTransformer(nn.Module):
    """Process the line through a CNN and process the resulting sequence with a Transformer decoder."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # NOTE(review): " " is also the ordinary space character in emnist.MAPPING, where it appears
        # twice; this dict keeps the later index. Confirm " " is really the intended start token.
        self.start_token = inverse_mapping[" "]
        # decode() and forward() read end_token and padding_token, so they must be assigned here.
        # NOTE(review): "<E>" and "<P>" are assumed names of the end/padding entries in
        # data_config["mapping"] -- TODO confirm against the dataset's mapping.
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # Instantiate LineCNN with "num_classes" set to self.dim, so it emits dim-sized features per window
        data_config_for_line_cnn = {**data_config}
        data_config_for_line_cnn["mapping"] = list(range(self.dim))
        self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args)
        # LineCNN outputs (B, E, Sx) features, with E == dim

        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.pos_encoder = PositionalEncoding(d_model=self.dim)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def init_weights(self):
        """Uniformly initialize the token embedding and output projection in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, passed straight to LineCNN (which unpacks four dimensions)

        Returns
        -------
        torch.Tensor
            (Sx, B, E) sequence of embeddings with positional encoding added
        """
        x = self.line_cnn(x)  # (B, E, Sx)
        x = x * math.sqrt(self.dim)
        x = x.permute(2, 0, 1)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x using preceding ground truth y.

        Parameters
        ----------
        x
            (Sx, B, E) image encoded as a sequence
        y
            (B, Sy) with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) logits
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict sequences of tokens from input images auto-regressively.

        Parameters
        ----------
        x
            (B, Ch, H, W) image

        Returns
        -------
        torch.Tensor
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, S)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Early stop: once every row has emitted an end or padding token, the clean-up loop
            # below forces all later positions to padding, so further decoding cannot change the result.
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    @staticmethod
    def add_to_argparse(parser):
        LineCNN.add_to_argparse(parser)
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab08/text_recognizer/models/mlp.py
================================================
import argparse
from typing import Any, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
FC1_DIM = 1024
FC2_DIM = 128
FC_DROPOUT = 0.5


class MLP(nn.Module):
    """Simple MLP suitable for recognizing single characters."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.data_config = data_config

        # np.prod returns a NumPy integer; cast so nn.Linear stores a built-in int as in_features.
        input_dim = int(np.prod(self.data_config["input_dims"]))
        num_classes = len(self.data_config["mapping"])

        fc1_dim = self.args.get("fc1", FC1_DIM)
        fc2_dim = self.args.get("fc2", FC2_DIM)
        dropout_p = self.args.get("fc_dropout", FC_DROPOUT)

        self.fc1 = nn.Linear(input_dim, fc1_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.fc3 = nn.Linear(fc2_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten x and apply fc1 -> ReLU -> dropout -> fc2 -> ReLU -> dropout -> fc3.

        Parameters
        ----------
        x
            (B, *input_dims) tensor; everything after the batch dimension is flattened.

        Returns
        -------
        torch.Tensor
            (B, num_classes) fc3 outputs, one score per class.
        """
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        return x

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--fc1", type=int, default=FC1_DIM)
        parser.add_argument("--fc2", type=int, default=FC2_DIM)
        parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT)
        return parser
================================================
FILE: lab08/text_recognizer/models/resnet_transformer.py
================================================
"""Model combining a ResNet with a Transformer for image-to-sequence tasks."""
import argparse
import math
from typing import Any, Dict
import torch
from torch import nn
import torchvision
from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage
TF_DIM = 256
TF_FC_DIM = 1024
TF_DROPOUT = 0.4
TF_LAYERS = 4
TF_NHEAD = 4
RESNET_DIM = 512  # hard-coded


class ResnetTransformer(nn.Module):
    """Pass an image through a Resnet and decode the resulting embedding with a Transformer."""

    def __init__(
        self,
        data_config: Dict[str, Any],
        args: argparse.Namespace = None,
    ) -> None:
        super().__init__()
        self.data_config = data_config
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])
        self.mapping = data_config["mapping"]
        inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
        # NOTE(review): " " is also the ordinary space character in emnist.MAPPING, where it appears
        # twice; this dict keeps the later index. Confirm " " is really the intended start token.
        self.start_token = inverse_mapping[" "]
        # decode() and forward() read end_token and padding_token, so they must be assigned here.
        # NOTE(review): "<E>" and "<P>" are assumed names of the end/padding entries in
        # data_config["mapping"] -- TODO confirm against the dataset's mapping.
        self.end_token = inverse_mapping["<E>"]
        self.padding_token = inverse_mapping["<P>"]
        self.max_output_length = data_config["output_dims"][0]
        self.args = vars(args) if args is not None else {}

        self.dim = self.args.get("tf_dim", TF_DIM)
        tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
        tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
        tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
        tf_layers = self.args.get("tf_layers", TF_LAYERS)

        # ## Encoder part - should output vector sequence of length self.dim per sample
        resnet = torchvision.models.resnet18(weights=None)
        self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2]))  # Exclude AvgPool and Linear layers
        # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32

        self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
        # encoder_projection will output (B, dim, _H, _W) logits

        self.enc_pos_encoder = PositionalEncodingImage(
            d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
        )  # Max (Ho, Wo)

        # ## Decoder part
        self.embedding = nn.Embedding(self.num_classes, self.dim)
        self.fc = nn.Linear(self.dim, self.num_classes)

        self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)

        self.y_mask = generate_square_subsequent_mask(self.max_output_length)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
            num_layers=tf_layers,
        )

        self.init_weights()  # This is empirically important

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Autoregressively produce sequences of labels from input images.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        output_tokens
            (B, Sy) with elements in [0, C-1] where C is num_classes
        """
        B = x.shape[0]
        S = self.max_output_length
        x = self.encode(x)  # (Sx, B, E)

        output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long()  # (B, Sy)
        output_tokens[:, 0] = self.start_token  # Set start token
        for Sy in range(1, S):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(x, y)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1]  # Set the last output token

            # Early stopping of prediction loop to speed up prediction
            if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
                break

        # Set all tokens after end or padding token to be padding
        for Sy in range(1, S):
            ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
            output_tokens[ind, Sy] = self.padding_token

        return output_tokens  # (B, Sy)

    def init_weights(self):
        """Initialize the embedding, output projection and encoder projection.

        Embedding and output weights are uniform in [-0.1, 0.1] with a zero output bias; the 1x1
        encoder projection gets Kaiming-normal (fan_out, ReLU) weights and a bias drawn uniformly
        from [-1/sqrt(fan_out), 1/sqrt(fan_out)].
        """
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

        nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
        if self.encoder_projection.bias is not None:
            _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data)
            bound = 1 / math.sqrt(fan_out)
            # uniform_(tensor, a, b): normal_(tensor, -bound, bound) would instead mean
            # mean=-bound, std=bound, biasing every channel negative.
            nn.init.uniform_(self.encoder_projection.bias, -bound, bound)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each image tensor in a batch into a sequence of embeddings.

        Parameters
        ----------
        x
            (B, Ch, H, W) image, where Ch == 1 or Ch == 3

        Returns
        -------
        (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps
        """
        _B, C, _H, _W = x.shape
        if C == 1:
            x = x.repeat(1, 3, 1, 1)
        x = self.resnet(x)  # (B, RESNET_DIM, _H // 32, _W // 32),   (B, 512, 18, 20) in the case of IAMParagraphs
        x = self.encoder_projection(x)  # (B, E, _H // 32, _W // 32),   (B, 256, 18, 20) in the case of IAMParagraphs

        # x = x * math.sqrt(self.dim)  # (B, E, _H // 32, _W // 32)  # This prevented any learning
        x = self.enc_pos_encoder(x)  # (B, E, Ho, Wo);     Ho = _H // 32, Wo = _W // 32
        x = torch.flatten(x, start_dim=2)  # (B, E, Ho * Wo)
        x = x.permute(2, 0, 1)  # (Sx, B, E);    Sx = Ho * Wo
        return x

    def decode(self, x, y):
        """Decode a batch of encoded images x with guiding sequences y.

        During autoregressive inference, the guiding sequence will be previous predictions.

        During training, the guiding sequence will be the ground truth.

        Parameters
        ----------
        x
            (Sx, B, E) images encoded as sequences of embeddings
        y
            (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes

        Returns
        -------
        torch.Tensor
            (Sy, B, C) batch of logit sequences
        """
        y_padding_mask = y == self.padding_token
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.dec_pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(x)
        output = self.transformer_decoder(
            tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
        )  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        return output

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--tf_dim", type=int, default=TF_DIM)
        # Default is TF_FC_DIM (previously TF_DIM), so CLI runs match the args=None fallback in __init__.
        parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM)
        parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
        parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
        parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
        return parser
================================================
FILE: lab08/text_recognizer/models/transformer_util.py
================================================
"""Position Encoding and other utilities for Transformers."""
import math
import torch
from torch import Tensor
import torch.nn as nn
class PositionalEncodingImage(nn.Module):
    """
    Adds fixed 2-D sinusoidal positional encodings to the feature-map produced by the encoder.

    The first d_model // 2 channels encode the row index and the remaining channels the column
    index, following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        # The table can be rebuilt from the constructor arguments, so by default it stays out of the state_dict.
        self.register_buffer("pe", self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w), persistent=persistent)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
        """Build the (d_model, max_h, max_w) table of row encodings stacked on column encodings."""
        half = d_model // 2
        row_codes = PositionalEncoding.make_pe(d_model=half, max_len=max_h)[:, 0, :]  # (max_h, half)
        col_codes = PositionalEncoding.make_pe(d_model=half, max_len=max_w)[:, 0, :]  # (max_w, half)
        pe_h = row_codes.t().unsqueeze(2).expand(half, max_h, max_w)  # constant along each row
        pe_w = col_codes.t().unsqueeze(1).expand(half, max_h, max_w)  # constant along each column
        return torch.cat([pe_h, pe_w], dim=0)  # (d_model, max_h, max_w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the top-left (H, W) window of the table to x of shape (B, d_model, H, W)."""
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        height, width = x.size(2), x.size(3)
        return x + self.pe[:, :height, :width]  # type: ignore
class PositionalEncoding(torch.nn.Module):
    """Classic Attention-is-all-you-need positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None:
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        pe = self.make_pe(d_model=d_model, max_len=max_len)  # (max_len, 1, d_model)
        self.register_buffer(
            "pe", pe, persistent=persistent
        )  # not necessary to persist in state_dict, since it can be remade

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> torch.Tensor:
        """Build the (max_len, 1, d_model) sinusoidal table: sine on even channels, cosine on odd channels.

        Odd d_model is supported: the last (even-indexed) channel gets a sine with no cosine partner.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (ceil(d/2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Only d_model // 2 odd channels exist; slicing div_term avoids a shape mismatch for odd d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)
        return pe

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first S rows of the table to x of shape (S, B, d_model), then apply dropout."""
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    """Generate a triangular (size, size) additive attention mask.

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when j > i.
    """
    # triu(diagonal=1) keeps the -inf strictly above the diagonal and zeroes the rest.
    return torch.triu(torch.full((size, size), float("-inf"), dtype=torch.float), diagonal=1)
================================================
FILE: lab08/text_recognizer/paragraph_text_recognizer.py
================================================
"""Detects a paragraph of text in an input image.
Example usage as a script:
python text_recognizer/paragraph_text_recognizer.py \
text_recognizer/tests/support/paragraphs/a01-077.png
python text_recognizer/paragraph_text_recognizer.py \
https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png
"""
import argparse
from pathlib import Path
from typing import Sequence, Union
from PIL import Image
import torch
from text_recognizer import util
from text_recognizer.stems.paragraph import ParagraphStem
# Default location of the staged TorchScript model, relative to this file:
# <this dir>/artifacts/paragraph-text-recognizer/model.pt
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "paragraph-text-recognizer"
MODEL_FILE = "model.pt"
class ParagraphTextRecognizer:
    """Recognizes a paragraph of text in an image using a TorchScript model.

    The loaded model must expose `mapping` (index -> character) and
    `ignore_tokens` (indices dropped from the output string).
    """

    def __init__(self, model_path=None):
        """Load the model from `model_path`, or from the staged default when None."""
        path = STAGED_MODEL_DIRNAME / MODEL_FILE if model_path is None else model_path
        self.model = torch.jit.load(path)
        self.mapping = self.model.mapping
        self.ignore_tokens = self.model.ignore_tokens
        self.stem = ParagraphStem()

    @torch.no_grad()
    def predict(self, image: Union[str, Path, Image.Image]) -> str:
        """Predict/infer text in input image (which can be a file path or url)."""
        if isinstance(image, Image.Image):
            image_pil = image
        else:
            image_pil = util.read_image_pil(image, grayscale=True)
        batch = self.stem(image_pil).unsqueeze(dim=0)  # add a leading batch dimension of size 1
        first_prediction = self.model(batch)[0]
        return convert_y_label_to_string(y=first_prediction, mapping=self.mapping, ignore_tokens=self.ignore_tokens)
def convert_y_label_to_string(y: torch.Tensor, mapping: Sequence[str], ignore_tokens: Sequence[int]) -> str:
    """Decode a 1-D sequence of token indices into a string.

    Parameters
    ----------
    y
        Token indices, as a tensor or any sequence of ints.
    mapping
        Maps each index to its character.
    ignore_tokens
        Indices (tensor or sequence of ints) to drop from the output, e.g. padding/start/end tokens.
    """
    # Convert to plain ints once: avoids per-element 0-d tensor indexing and comparisons.
    tokens = y.tolist() if isinstance(y, torch.Tensor) else y
    ignored = set(ignore_tokens.tolist() if isinstance(ignore_tokens, torch.Tensor) else ignore_tokens)
    return "".join([mapping[token] for token in tokens if token not in ignored])
def main():
    """Command-line entry point: recognize the text in one image and print it."""
    description = __doc__.split("\n")[0]
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "filename",
        type=str,
        help="Name for an image file. This can be a local path, a URL, a URI from AWS/GCP/Azure storage, an HDFS path, or any other resource locator supported by the smart_open library.",
    )
    filename = parser.parse_args().filename
    recognizer = ParagraphTextRecognizer()
    print(recognizer.predict(filename))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
================================================
FILE: lab08/text_recognizer/stems/image.py
================================================
import torch
from torchvision import transforms
class ImageStem:
    """A stem for models operating on images.

    Inputs are expected to be PIL images, as is standard for torchvision Datasets.
    Processing happens in three stages:
      1. pil_transforms: PIL image -> PIL image (identity by default),
      2. pil_to_tensor: PIL image -> torch tensor,
      3. torch_transforms: tensor -> tensor (identity by default), run under no_grad.
    torch_transforms is a torch.nn.Sequential, so it is compatible with TorchScript
    whenever the underlying Modules are.
    """

    def __init__(self):
        self.pil_transforms = transforms.Compose([])
        self.pil_to_tensor = transforms.ToTensor()
        self.torch_transforms = torch.nn.Sequential()

    def __call__(self, img):
        as_pil = self.pil_transforms(img)
        as_tensor = self.pil_to_tensor(as_pil)
        with torch.no_grad():
            return self.torch_transforms(as_tensor)
class MNISTStem(ImageStem):
    """A stem for handling images from the MNIST dataset: ImageStem plus a normalization step."""

    def __init__(self):
        super().__init__()
        # presumably the MNIST training-set pixel mean/std — confirm against the dataset
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.torch_transforms = torch.nn.Sequential(normalize)
================================================
FILE: lab08/text_recognizer/stems/line.py
================================================
import random
from PIL import Image
from torchvision import transforms
import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.stems.image import ImageStem
class LineStem(ImageStem):
    """A stem for handling images containing a line of text.

    With augment=True, PIL images go through ColorJitter and then RandomAffine before
    tensor conversion; otherwise the inherited identity transforms are kept.
    Either kwargs dict replaces (does not merge with) its default.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()
        jitter_kwargs = {"brightness": (0.5, 1)} if color_jitter_kwargs is None else color_jitter_kwargs
        affine_kwargs = random_affine_kwargs
        if affine_kwargs is None:
            affine_kwargs = {
                "degrees": 3,
                "translate": (0, 0.05),
                "scale": (0.4, 1.1),
                "shear": (-40, 50),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }
        if not augment:
            return
        self.pil_transforms = transforms.Compose(
            [transforms.ColorJitter(**jitter_kwargs), transforms.RandomAffine(**affine_kwargs)]
        )
class IAMLineStem(ImageStem):
    """A stem for handling images containing lines of text from the IAMLines dataset.

    Each crop is resized to metadata.IMAGE_HEIGHT (keeping its aspect ratio) and pasted
    onto a zero-filled "L"-mode canvas of size (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT).
    With augment=True, the resized width is randomly stretched by 0.9-1.1x, and
    ColorJitter + RandomAffine are applied after embedding.
    """

    def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None):
        super().__init__()

        # `augment` is bound as a default argument of the inner function.
        def embed_crop(crop, augment=augment):
            # crop is PIL.image of dtype="L" (so values range from 0 -> 255)
            image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))

            # Resize crop
            crop_width, crop_height = crop.size
            new_crop_height = metadata.IMAGE_HEIGHT
            new_crop_width = int(new_crop_height * (crop_width / crop_height))  # preserve aspect ratio
            if augment:
                # Add random stretching
                new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
                new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH)
            crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR)

            # Embed in the image
            # Left offset is at most CHAR_WIDTH. NOTE(review): the width is only clamped to
            # IMAGE_WIDTH when augmenting, so a wider un-augmented crop gives a negative x — confirm intended.
            x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width)
            y = metadata.IMAGE_HEIGHT - new_crop_height  # 0, since new_crop_height == IMAGE_HEIGHT

            image.paste(crop_resized, (x, y))

            return image

        if color_jitter_kwargs is None:
            color_jitter_kwargs = {"brightness": (0.8, 1.6)}
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 1,
                "shear": (-30, 20),
                "interpolation": transforms.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        # Embedding always runs first; augmentations act on the embedded, fixed-size image.
        pil_transforms_list = [transforms.Lambda(embed_crop)]
        if augment:
            pil_transforms_list += [
                transforms.ColorJitter(**color_jitter_kwargs),
                transforms.RandomAffine(**random_affine_kwargs),
            ]
        self.pil_transforms = transforms.Compose(pil_transforms_list)
================================================
FILE: lab08/text_recognizer/stems/paragraph.py
================================================
"""IAMParagraphs Stem class."""
import torchvision.transforms as transforms
import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.stems.image import ImageStem
# Re-exported from text_recognizer.metadata.iam_paragraphs for convenience.
IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE

MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
class ParagraphStem(ImageStem):
    """A stem for handling images that contain a paragraph of text.

    Without augmentation, images are center-cropped to IMAGE_SHAPE. With augmentation,
    they are color-jittered, randomly cropped (padded with 0 if too small) to IMAGE_SHAPE,
    then randomly warped, blurred, and sharpened. Each *_kwargs dict, when given,
    replaces the corresponding default; they are ignored when augment=False.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()

        if not augment:
            self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)])
        else:
            if color_jitter_kwargs is None:
                color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4}
            if random_affine_kwargs is None:
                random_affine_kwargs = {
                    "degrees": 3,
                    "shear": 6,
                    "scale": (0.95, 1),
                    "interpolation": transforms.InterpolationMode.BILINEAR,
                }
            if random_perspective_kwargs is None:
                random_perspective_kwargs = {
                    "distortion_scale": 0.2,
                    "p": 0.5,
                    "interpolation": transforms.InterpolationMode.BILINEAR,
                }
            if gaussian_blur_kwargs is None:
                gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
            if sharpness_kwargs is None:
                sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5}

            # IMAGE_SHAPE is (576, 640) — per metadata.iam_paragraphs at time of writing; confirm there
            self.pil_transforms = transforms.Compose(
                [
                    transforms.ColorJitter(**color_jitter_kwargs),
                    transforms.RandomCrop(
                        size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant"
                    ),
                    transforms.RandomAffine(**random_affine_kwargs),
                    transforms.RandomPerspective(**random_perspective_kwargs),
                    transforms.GaussianBlur(**gaussian_blur_kwargs),
                    transforms.RandomAdjustSharpness(**sharpness_kwargs),
                ]
            )
================================================
FILE: lab08/text_recognizer/tests/test_callback_utils.py
================================================
"""Tests for the text_recognizer.callbacks.util module."""
import random
import string
import tempfile
import pytorch_lightning as pl
from text_recognizer.callbacks.util import check_and_warn
def test_check_and_warn_simple():
    """Test the success and failure in the case of a simple class we control."""

    class Foo:
        pass  # a class with no special attributes

    missing_attr = "".join(random.choices(string.ascii_lowercase, k=10))
    # a random 10-letter name is (almost surely) absent, so a warning is expected
    assert check_and_warn(Foo(), missing_attr, "random feature")
    # every Python object has __doc__, so no warning is expected
    assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects")
def test_check_and_warn_tblogger():
    """Test that we return a truthy value when trying to log tables with TensorBoard.

    We added check_and_warn in order to prevent a crash if this happens.
    """
    # Pass the directory's *path* (not the TemporaryDirectory object) and clean it up afterwards.
    with tempfile.TemporaryDirectory() as save_dir:
        tblogger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
        assert check_and_warn(tblogger, "log_table", "tables")
def test_check_and_warn_wandblogger():
    """Test that we return a falsy value when we try to log tables with W&B.

    In adding check_and_warn, we don't want to block the feature in the happy path.
    """
    logger = pl.loggers.WandbLogger(anonymous=True)
    warned = check_and_warn(logger, "log_table", "tables")
    assert not warned
================================================
FILE: lab08/text_recognizer/tests/test_iam.py
================================================
"""Test for data.iam module."""
from text_recognizer.data.iam import IAM
def test_iam_parsed_lines():
    """Tests that we retrieve the same number of line labels and line image cropregions."""
    iam = IAM()
    iam.prepare_data()
    for iam_id in iam.all_ids:
        strings = iam.line_strings_by_id[iam_id]
        regions = iam.line_regions_by_id[iam_id]
        assert len(strings) == len(regions)
def test_iam_data_splits():
    """Fails when any identifiers are shared between training, test, or validation."""
    iam = IAM()
    iam.prepare_data()
    assert set(iam.train_ids).isdisjoint(set(iam.validation_ids))
    assert set(iam.train_ids).isdisjoint(set(iam.test_ids))
    assert set(iam.validation_ids).isdisjoint(set(iam.test_ids))
================================================
FILE: lab08/text_recognizer/util.py
================================================
"""Utility functions for text_recognizer module."""
import base64
import contextlib
import hashlib
from io import BytesIO
import os
from pathlib import Path
from typing import Union
from urllib.request import urlretrieve
import numpy as np
from PIL import Image
import smart_open
from tqdm import tqdm
def to_categorical(y, num_classes):
    """One-hot encode integer labels `y` as uint8, appending a trailing axis of size num_classes."""
    identity = np.identity(num_classes, dtype=np.uint8)
    return identity[y]
def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image.Image:
    """Read an image from a local path or any URI supported by smart_open.

    Parameters
    ----------
    image_uri
        Local path, URL, or cloud-storage URI.
    grayscale
        If True, convert the image to mode "L".

    Returns
    -------
    A PIL image that does not depend on the source stream staying open.
    """
    with smart_open.open(image_uri, "rb") as image_file:
        return read_image_pil_file(image_file, grayscale)
def read_image_pil_file(image_file, grayscale=False) -> Image.Image:
    """Load a PIL image from a file path or binary file-like object.

    The image is always converted (to "L" if grayscale, else to its own mode), which
    yields a fully-loaded copy that remains valid after the opened image is closed.
    """
    with Image.open(image_file) as image:
        if grayscale:
            image = image.convert(mode="L")
        else:
            image = image.convert(mode=image.mode)
        return image
@contextlib.contextmanager
def temporary_working_directory(working_dir: Union[str, Path]):
    """Change into `working_dir` for the duration of the with-block.

    The previous working directory is restored on exit, including when the body raises.
    """
    original = os.getcwd()
    os.chdir(working_dir)
    with contextlib.ExitStack() as restore:
        restore.callback(os.chdir, original)
        yield
def read_b64_image(b64_string, grayscale=False):
    """Decode a base64 data-URI string into a PIL image.

    Any failure is re-raised as ValueError, chained to the original exception.
    """
    try:
        return read_image_pil_file(read_b64_string(b64_string), grayscale)
    except Exception as exception:
        raise ValueError("Could not load image from b64 {}: {}".format(b64_string, exception)) from exception
def read_b64_string(b64_string, return_data_type=False):
    """Decode a base64 data-URI string into an in-memory binary buffer.

    With return_data_type=True, returns (file_type, buffer), where file_type is the
    part of the MIME type after "/" (e.g. "png"); otherwise returns only the buffer.
    """
    header, payload = split_and_validate_b64_string(b64_string)
    buffer = BytesIO(base64.b64decode(payload))
    if not return_data_type:
        return buffer
    return get_b64_filetype(header), buffer
def get_b64_filetype(data_header):
    """Return the subtype of a "type/subtype" MIME string, e.g. "png" for "image/png".

    Raises ValueError unless the header contains exactly one "/".
    """
    _major, subtype = data_header.split("/")
    return subtype
class _InvalidB64StringError(ValueError, AssertionError):
    """Raised when a string is not of the form ``data:<type>;base64,<payload>``.

    Also subclasses AssertionError so callers written against the former
    assert-based validation keep catching it.
    """


def split_and_validate_b64_string(b64_string):
    """Return the data_type and data of a b64 string, with validation.

    Parameters
    ----------
    b64_string
        A data URI such as ``data:image/png;base64,iVBOR...``.

    Returns
    -------
    Tuple[str, str]
        The MIME type (e.g. ``image/png``) and the base64 payload.

    Raises
    ------
    _InvalidB64StringError (a ValueError and AssertionError)
        If the separator, ``data:`` prefix, or ``;base64`` suffix is missing.
    """
    header, separator, data = b64_string.partition(",")
    # Explicit raises instead of `assert`, which is stripped when Python runs with -O.
    if not separator:
        raise _InvalidB64StringError("base64 string must contain ',' between header and data")
    if not header.startswith("data:"):
        raise _InvalidB64StringError(f"base64 header must start with 'data:', got {header!r}")
    if not header.endswith(";base64"):
        raise _InvalidB64StringError(f"base64 header must end with ';base64', got {header!r}")
    data_type = header.split(";")[0].split(":")[1]
    return data_type, data
def encode_b64_image(image, format="png"):
    """Serialize a PIL image in `format` and return the bytes as a base64 ASCII string.

    The result has no ``data:`` header.
    """
    with BytesIO() as in_memory_file:
        image.save(in_memory_file, format=format)
        raw_bytes = in_memory_file.getvalue()
    return base64.b64encode(raw_bytes).decode("utf8")
def compute_sha256(filename: Union[Path, str], chunk_size: int = 1 << 20):
    """Return SHA256 checksum of a file as a hex string.

    The file is hashed incrementally in `chunk_size`-byte pieces, so memory use
    stays bounded even for large downloads.
    """
    digest = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
class TqdmUpTo(tqdm):
    """A tqdm progress bar usable as a urlretrieve `reporthook`.

    Adapted from https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
    """

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """
        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        transferred = blocks * bsize
        # update() takes an increment; this moves self.n to `transferred`
        self.update(transferred - self.n)
def download_url(url, filename):
    """Download `url` to the local path `filename`, showing a byte-count progress bar."""
    bar_settings = dict(unit="B", unit_scale=True, unit_divisor=1024, miniters=1)
    with TqdmUpTo(**bar_settings) as progress:
        urlretrieve(url, filename, reporthook=progress.update_to, data=None)  # noqa: S310
================================================
FILE: lab08/training/__init__.py
================================================
================================================
FILE: lab08/training/cleanup_artifacts.py
================================================
"""Removes artifacts from projects and runs.
Artifacts are binary files that we want to track
and version but don't want to include in git,
generally because they are too large,
because they don't have meaningful diffs,
or because they change more quickly than code.
During development, we often generate artifacts
that we don't really need, e.g. model weights for
an overfitting test run. Space on artifact storage
is generally very large, but it is limited,
so we should occasionally delete unneeded artifacts
to reclaim some of that space.
For usage help, run
python training/cleanup_artifacts.py --help
"""
import argparse
import wandb
# Module-level W&B API client, created at import time; used to look up runs and artifacts.
api = wandb.Api()

DEFAULT_PROJECT = "fsdl-text-recognizer-2022-training"
# Read once from the API client at import time; used when --entity=DEFAULT is passed.
DEFAULT_ENTITY = api.default_entity
def _setup_parser():
    """Build the command-line parser for the artifact-cleanup script.

    Runs are chosen with --run_ids and/or --run_name_res. At most one artifact
    filter (--all, --no-alias, --aliases) may be given; with none, nothing is selected.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--entity",
        type=str,
        default=None,
        help="The entity from which to remove artifacts. Provide the value DEFAULT "
        + f"to use the default WANDB_ENTITY, which is currently {DEFAULT_ENTITY}.",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=DEFAULT_PROJECT,
        help=f"The project from which to remove artifacts. Default is {DEFAULT_PROJECT}",
    )
    parser.add_argument(
        "--run_ids",
        type=str,
        default=None,
        nargs="*",
        help="One or more run IDs from which to remove artifacts. Default is None.",
    )
    parser.add_argument(
        "--run_name_res",
        type=str,
        default=None,
        nargs="*",
        help="One or more regular expressions to use to select runs (by display name) from which to remove artifacts. See wandb.Api.runs documentation for details on the syntax. Beware that this is a footgun and consider using interactively with --dryrun and -v. Default is None.",
        metavar="RUN_NAME_REGEX",
    )

    # The artifact filters are mutually exclusive.
    flags = parser.add_mutually_exclusive_group()
    flags.add_argument("--all", action="store_true", help="Delete all artifacts from selected runs.")
    flags.add_argument(
        "--no-alias", action="store_true", help="Delete all artifacts without an alias from selected runs."
    )
    flags.add_argument(
        "--aliases",
        type=str,
        nargs="*",
        help="Delete artifacts that have any of the aliases from the provided list from selected runs.",
    )

    parser.add_argument(
        "-v",
        action="store_true",
        dest="verbose",
        help="Display information about targeted entities, projects, runs, and artifacts.",
    )
    parser.add_argument(
        "--dryrun",
        action="store_true",
        help="Select artifacts without deleting them and display which artifacts were selected.",
    )
    return parser
def main(args):
    """Select runs and artifacts from the parsed CLI args, then delete the artifacts.

    With --dryrun, the selected artifacts are only printed. With -v, the entity,
    selector, run, and artifact choices are all reported.
    """
    # Forward verbose everywhere; previously entity/selector messages could never print.
    entity = _get_entity_from(args, verbose=args.verbose)
    project_path = f"{entity}/{args.project}"

    runs = _get_runs(project_path, args.run_ids, args.run_name_res, verbose=args.verbose)
    artifact_selector = _get_selector_from(args, verbose=args.verbose)
    protect_aliases = args.no_alias  # avoid deletion of any aliased artifacts

    for run in runs:
        clean_run_artifacts(
            run, selector=artifact_selector, protect_aliases=protect_aliases, verbose=args.verbose, dryrun=args.dryrun
        )
def clean_run_artifacts(run, selector, protect_aliases=True, verbose=False, dryrun=True):
    """Remove every artifact logged by `run` for which `selector(artifact)` is truthy."""
    chosen = (artifact for artifact in run.logged_artifacts() if selector(artifact))
    for artifact in chosen:
        remove_artifact(artifact, protect_aliases=protect_aliases, verbose=verbose, dryrun=dryrun)
def remove_artifact(artifact, protect_aliases, verbose=False, dryrun=True):
    """Delete one W&B artifact, or only report it when dryrun is True.

    Parameters
    ----------
    artifact
        A W&B API artifact handle.
    protect_aliases
        If True, pass delete_aliases=False so deletion does not also strip aliases.
    verbose, dryrun
        The selection is printed if either is set; nothing is deleted when dryrun is True.
    """
    # entity/project/id order, consistent with how runs are reported elsewhere in this script
    artifact_path = f"{artifact.entity}/{artifact.project}/{artifact.id}"
    if verbose or dryrun:
        print(
            f"selecting for deletion artifact {artifact_path} of type {artifact.type} with aliases {artifact.aliases}"
        )
    if not dryrun:
        artifact.delete(delete_aliases=not protect_aliases)
def _get_runs(project_path, run_ids=None, run_name_res=None, verbose=False):
if run_ids is None:
run_ids = []
if run_name_res is None:
run_name_res = []
runs = []
for run_id in run_ids:
runs.append(_get_run_by_id(project_path, run_id, verbose=verbose))
for run_name_re in run_name_res:
runs += _get_runs_by_name_re(project_path, run_name_re, verbose=verbose)
return runs
def _get_run_by_id(project_path, run_id, verbose=False):
    """Fetch a single run at `project_path/run_id` via the module-level W&B API."""
    run = api.run(f"{project_path}/{run_id}")
    if verbose:
        print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}")
    return run
def _get_runs_by_name_re(project_path, run_name_re, verbose=False):
    """Return the runs in `project_path` whose display name matches the regex `run_name_re`."""
    name_filter = {"display_name": {"$regex": run_name_re}}
    matching_runs = api.runs(path=project_path, filters=name_filter)
    if verbose:
        for run in matching_runs:
            print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}")
    return matching_runs
def _get_selector_from(args, verbose=False):
if args.all:
if verbose:
print("removing all artifacts from matching runs")
return lambda _: True
if args.no_alias:
if verbose:
print("removing all artifacts with no aliases from matching runs")
return lambda artifact: artifact.aliases == []
if args.aliases:
if verbose:
print(f"removing all artifacts with any of {args.aliases} in aliases from matching runs")
return lambda artifact: any(alias in artifact.aliases for alias in args.aliases)
if verbose:
print("removing no artifacts matching runs")
return lambda _: False
def _get_entity_from(args, verbose=False):
entity = args.entity
if entity is None:
raise RuntimeError(f"No entity argument provided. Use --entity=DEFAULT to use {DEFAULT_ENTITY}.")
elif entity == "DEFAULT":
entity = DEFAULT_ENTITY
if verbose:
print(f"using default entity {entity}")
else:
if verbose:
print(f"using entity {entity}")
return entity
# Script entry point: parse CLI flags and run the cleanup.
if __name__ == "__main__":
    parser = _setup_parser()
    args = parser.parse_args()
    main(args)
================================================
FILE: lab08/training/run_experiment.py
================================================
"""Experiment-running framework."""
import argparse
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
import torch
from text_recognizer import callbacks as cb
from text_recognizer import lit_models
from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args
# In order to ensure reproducible experiments, we must set random seeds.
# Runs at import time and seeds NumPy's and torch's global RNGs.
np.random.seed(42)
torch.manual_seed(42)
def _setup_parser():
    """Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
    # add_help=False so that --help can be registered last, after the data/model-specific groups exist.
    parser = argparse.ArgumentParser(add_help=False)

    # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
    trainer_parser = pl.Trainer.add_argparse_args(parser)
    # NOTE(review): touches argparse's private _action_groups; index 1 is presumably the group holding
    # the Trainer flags — confirm against the installed pytorch_lightning / argparse versions.
    trainer_parser._action_groups[1].title = "Trainer Args"
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
    parser.set_defaults(max_epochs=1)

    # Basic arguments
    parser.add_argument(
        "--wandb",
        action="store_true",
        default=False,
        help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.",
    )
    parser.add_argument(
        "--data_class",
        type=str,
        default="MNIST",
        help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--model_class",
        type=str,
        default="MLP",
        help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.",
    )
    parser.add_argument(
        "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path."
    )
    parser.add_argument(
        "--stop_early",
        type=int,
        default=0,
        help="If non-zero, applies early stopping, with the provided value as the 'patience' argument."
        + " Default is 0.",
    )

    # Get the data and model classes, so that we can add their specific arguments
    # parse_known_args tolerates the data/model/LitModel flags that are not registered yet.
    temp_args, _ = parser.parse_known_args()
    data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}")
    model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}")

    # Get data, model, and LitModel specific arguments
    data_group = parser.add_argument_group("Data Args")
    data_class.add_to_argparse(data_group)

    model_group = parser.add_argument_group("Model Args")
    model_class.add_to_argparse(model_group)

    lit_model_group = parser.add_argument_group("LitModel Args")
    lit_models.BaseLitModel.add_to_argparse(lit_model_group)

    # Registered last so --help lists every group added above.
    parser.add_argument("--help", "-h", action="help")
    return parser
@rank_zero_only
def _ensure_logging_dir(experiment_dir):
    """Create `experiment_dir` (and any missing parents) from the rank-zero process only."""
    directory = Path(experiment_dir)
    directory.mkdir(exist_ok=True, parents=True)
def main():
    """
    Run an experiment.

    Sample command:
    ```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST
    ```

    For basic help documentation, run the command
    ```
    python training/run_experiment.py --help
    ```

    The available command line args differ depending on some of the arguments, including --model_class and --data_class.

    To see which command line args are available and read their documentation, provide values for those arguments
    before invoking --help, like so:
    ```
    python training/run_experiment.py --model_class=MLP --data_class=MNIST --help
    ```
    """
    parser = _setup_parser()
    args = parser.parse_args()
    data, model = setup_data_and_model_from_args(args)

    # The LightningModule wrapper is chosen by --loss.
    lit_model_class = lit_models.BaseLitModel

    if args.loss == "transformer":
        lit_model_class = lit_models.TransformerLitModel

    if args.load_checkpoint is not None:
        lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model)
    else:
        lit_model = lit_model_class(args=args, model=model)

    log_dir = Path("training") / "logs"
    _ensure_logging_dir(log_dir)
    logger = pl.loggers.TensorBoardLogger(log_dir)
    experiment_dir = logger.log_dir

    # Checkpoints are ranked by CER for the transformer loss, otherwise by validation loss.
    goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss"
    filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}"
    if goldstar_metric == "validation/cer":
        filename_format += "-validation.cer={validation/cer:.3f}"
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=5,
        filename=filename_format,
        monitor=goldstar_metric,
        mode="min",
        auto_insert_metric_name=False,
        dirpath=experiment_dir,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    summary_callback = pl.callbacks.ModelSummary(max_depth=2)

    callbacks = [summary_callback, checkpoint_callback]
    if args.wandb:
        # Replaces the TensorBoard logger; experiment_dir then points at the W&B run directory.
        logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train")
        logger.watch(model, log_freq=max(100, args.log_every_n_steps))
        logger.log_hyperparams(vars(args))
        experiment_dir = logger.experiment.dir
    callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()]
    if args.stop_early:
        # Note: early stopping always monitors validation/loss, even when checkpoints rank by CER.
        early_stopping_callback = pl.callbacks.EarlyStopping(
            monitor="validation/loss", mode="min", patience=args.stop_early
        )
        callbacks.append(early_stopping_callback)

    if args.wandb and args.loss in ("transformer",):
        callbacks.append(cb.ImageToTextLogger())

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger)
    if args.profile:
        sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0)
        profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir)
        profiler.STEP_FUNCTIONS = {"training_step"}  # only profile training
    else:
        profiler = pl.profiler.PassThroughProfiler()

    # The profiler is attached after construction because its dirpath depends on the logger chosen above.
    trainer.profiler = profiler

    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate

    trainer.fit(lit_model, datamodule=data)

    trainer.profiler = pl.profiler.PassThroughProfiler()  # turn profiling off during testing

    # Test with the best checkpoint if one was saved; otherwise test the in-memory model.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        rank_zero_info(f"Best model saved at: {best_model_path}")
        if args.wandb:
            rank_zero_info("Best model also uploaded to W&B ")
        trainer.test(datamodule=data, ckpt_path=best_model_path)
    else:
        trainer.test(lit_model, datamodule=data)
# Script entry point.
if __name__ == "__main__":
    main()
================================================
FILE: lab08/training/stage_model.py
================================================
"""Stages a model for use in production.
If based on a checkpoint, the model is converted to torchscript, saved locally,
and uploaded to W&B.
If based on a model that is already converted and uploaded, the model file is downloaded locally.
For details on how the W&B artifacts backing the checkpoints and models are handled,
see the documentation for stage_model.find_artifact.
"""
import argparse
from pathlib import Path
import tempfile
import torch
import wandb
from text_recognizer.lit_models import TransformerLitModel
from training.util import setup_data_and_model_from_args
# these names are all set by the pl.loggers.WandbLogger
MODEL_CHECKPOINT_TYPE = "model"
BEST_CHECKPOINT_ALIAS = "best"
MODEL_CHECKPOINT_PATH = "model.ckpt"
LOG_DIR = Path("training") / "logs"

STAGED_MODEL_TYPE = "prod-ready"  # we can choose the name of this type, and ideally it's different from checkpoints
STAGED_MODEL_FILENAME = "model.pt"  # standard nomenclature; pytorch_model.bin is also used

PROJECT_ROOT = Path(__file__).resolve().parents[1]
LITMODEL_CLASS = TransformerLitModel  # LightningModule class used to reload checkpoints

# W&B API client created at import time; DEFAULT_ENTITY is read from it once.
api = wandb.Api()

DEFAULT_ENTITY = api.default_entity
DEFAULT_FROM_PROJECT = "fsdl-text-recognizer-2022-training"
DEFAULT_TO_PROJECT = "fsdl-text-recognizer-2022-training"
DEFAULT_STAGED_MODEL_NAME = "paragraph-text-recognizer"

# Staged models are written to PROD_STAGING_ROOT / <staged_model_name> / STAGED_MODEL_FILENAME.
PROD_STAGING_ROOT = PROJECT_ROOT / "text_recognizer" / "artifacts"
def main(args):
    """Stage a model for production.

    With --fetch, download the latest staged model artifact into the local staging
    directory. Otherwise, inside a W&B "stage" run: find the checkpoint, reload it,
    compile it to TorchScript in the staging directory, and upload it as a new artifact.
    """
    prod_staging_directory = PROD_STAGING_ROOT / args.staged_model_name
    prod_staging_directory.mkdir(exist_ok=True, parents=True)
    entity = _get_entity_from(args)
    # if we're just fetching an already compiled model
    if args.fetch:
        # find it and download it
        staged_model = f"{entity}/{args.from_project}/{args.staged_model_name}:latest"
        artifact = download_artifact(staged_model, prod_staging_directory)
        print_info(artifact)
        return  # and we're done

    # otherwise, we'll need to download the weights, compile the model, and save it
    with wandb.init(
        job_type="stage", project=args.to_project, dir=LOG_DIR
    ):  # log staging to W&B so prod and training are connected
        # find the model checkpoint and retrieve its artifact name and an api handle
        ckpt_at, ckpt_api = find_artifact(
            entity, args.from_project, type=MODEL_CHECKPOINT_TYPE, alias=args.ckpt_alias, run=args.run
        )

        # get the run that produced that checkpoint
        logging_run = get_logging_run(ckpt_api)
        print_info(ckpt_api, logging_run)
        metadata = get_checkpoint_metadata(logging_run, ckpt_api)

        # create an artifact for the staged, deployable model
        staged_at = wandb.Artifact(args.staged_model_name, type=STAGED_MODEL_TYPE, metadata=metadata)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # download the checkpoint to a temporary directory
            download_artifact(ckpt_at, tmp_dir)
            # reload the model from that checkpoint
            model = load_model_from_checkpoint(metadata, directory=tmp_dir)
            # save the model to torchscript in the staging directory
            save_model_to_torchscript(model, directory=prod_staging_directory)

        # upload the staged model so it can be downloaded elsewhere
        # (still inside wandb.init, since wandb.log_artifact needs an active run)
        upload_staged_model(staged_at, from_directory=prod_staging_directory)
def find_artifact(entity: str, project: str, type: str, alias: str, run=None):
    """Finds the artifact of a given type with a given alias under the entity and project.

    Parameters
    ----------
    entity
        The name of the W&B entity under which the artifact is logged.
    project
        The name of the W&B project under which the artifact is logged.
    type
        The name of the type of the artifact.
    alias : str
        The alias for this artifact. This alias must be unique within the
        provided type for the run, if provided, or for the project,
        if the run is not provided.
    run : str
        Optionally, the run in which the artifact is located.

    Returns
    -------
    Tuple[path, artifact]
        An identifying path and an API handle for a matching artifact.
    """
    if run is None:
        path = _find_artifact_project(entity, project, type=type, alias=alias)
    else:
        path = _find_artifact_run(entity, project, type=type, run=run, alias=alias)
    handle = api.artifact(path)
    return path, handle
def get_logging_run(artifact):
    """Return the W&B API run that logged `artifact`."""
    return artifact.logged_by()
def print_info(artifact, run=None):
    """Print the artifact's name and URL, and the run that logged it.

    If `run` is None, it is looked up with get_logging_run(artifact).
    """
    if run is None:
        run = get_logging_run(artifact)

    full_artifact_name = f"{artifact.entity}/{artifact.project}/{artifact.name}"
    print(f"Using artifact {full_artifact_name}")
    artifact_url_prefix = f"https://wandb.ai/{artifact.entity}/{artifact.project}/artifacts/{artifact.type}"
    artifact_url_suffix = f"{artifact.name.replace(':', '/')}"  # "name:vN" -> "name/vN" in the URL
    print(f"View at URL: {artifact_url_prefix}/{artifact_url_suffix}")

    # entity/project/id order, matching the artifact path printed above
    print(f"Logged by {run.name} -- {run.entity}/{run.project}/{run.id}")
    print(f"View at URL: {run.url}")
def get_checkpoint_metadata(run, checkpoint):
    """Build metadata for a staged model from its training run and checkpoint artifact.

    Always includes the run's config under "config". Best-effort adds
    "original_filename" and {monitored metric name: score}, stopping silently at the
    first missing key (so a partial result is possible).
    """
    metadata = {"config": run.config}
    try:
        ckpt_meta = checkpoint.metadata
        metadata["original_filename"] = ckpt_meta["original_filename"]
        monitored = ckpt_meta["ModelCheckpoint"]["monitor"]
        metadata[monitored] = ckpt_meta["score"]
    except KeyError:
        pass
    return metadata
def download_artifact(artifact_path, target_directory):
    """Downloads the artifact at artifact_path to the target directory.

    Inside an active W&B run the artifact is fetched with wandb.use_artifact, so the
    run records it as an input; otherwise it is fetched through the public API.
    """
    inside_run = wandb.run is not None
    artifact = wandb.use_artifact(artifact_path) if inside_run else api.artifact(artifact_path)
    artifact.download(root=target_directory)
    return artifact
def load_model_from_checkpoint(ckpt_metadata, directory):
    """Rebuild the model from the run config and load weights from directory/MODEL_CHECKPOINT_PATH.

    Returns the LightningModule in eval mode. Loading uses strict=False, so missing or
    unexpected state_dict keys do not raise.
    """
    args = argparse.Namespace(**ckpt_metadata["config"])
    _, model = setup_data_and_model_from_args(args)
    checkpoint_path = Path(directory) / MODEL_CHECKPOINT_PATH
    lit_model = LITMODEL_CLASS.load_from_checkpoint(
        checkpoint_path=checkpoint_path, args=args, model=model, strict=False
    )
    lit_model.eval()
    return lit_model
def save_model_to_torchscript(model, directory):
    """Script `model` with TorchScript and save it to directory/STAGED_MODEL_FILENAME."""
    destination = Path(directory) / STAGED_MODEL_FILENAME
    scripted = model.to_torchscript(method="script", file_path=None)
    torch.jit.save(scripted, destination)
def upload_staged_model(staged_at, from_directory):
    """Attach the staged model file to `staged_at` and log it to the active W&B run."""
    model_file = Path(from_directory) / STAGED_MODEL_FILENAME
    staged_at.add_file(model_file)
    wandb.log_artifact(staged_at)
def _find_artifact_run(entity, project, type, run, alias):
    """Return the path of the unique artifact of `type` with `alias` logged by the given run.

    Raises ValueError if there is no match or more than one.
    """
    run_name = f"{entity}/{project}/{run}"
    logged = api.run(run_name).logged_artifacts()
    matches = [candidate for candidate in logged if alias in candidate.aliases and candidate.type == type]
    if len(matches) == 0:
        raise ValueError(f"No artifact with alias {alias} found at {run_name} of type {type}")
    if len(matches) > 1:
        raise ValueError(f"Multiple artifacts ({len(matches)}) with alias {alias} found at {run_name} of type {type}")
    (only_match,) = matches
    return f"{entity}/{project}/{only_match.name}"
def _find_artifact_project(entity, project, type, alias):
    """Return the full name of the first artifact version of ``type`` in a project tagged ``alias``.

    Only the first artifact type whose name equals ``type`` is searched. Its collections and
    their versions are scanned in the order the API returns them.

    Returns:
        A string of the form ``"{entity}/{project}/{version_name}"``.

    Raises:
        ValueError: if the artifact type is absent (or the project is private or missing), or
            if no version of that type carries ``alias``.
    """
    project_name = f"{entity}/{project}"
    api_project = api.project(project, entity=entity)

    matching_types = (kind for kind in api_project.artifacts_types() if kind.name == type)
    artifact_type = next(matching_types, None)
    if artifact_type is None:
        raise ValueError(f"Artifact type {type} not found. {project_name} could be private or not exist.")

    for collection in artifact_type.collections():
        for version in collection.versions():
            if alias in version.aliases:
                return f"{project_name}/{version.name}"
    raise ValueError(f"Artifact with alias {alias} not found in type {type} in {project_name}")
def _get_entity_from(args):
entity = args.entity
if entity is None:
raise RuntimeError(f"No entity argument provided. Use --entity=DEFAULT to use {DEFAULT_ENTITY}.")
elif entity == "DEFAULT":
entity = DEFAULT_ENTITY
return entity
def _setup_parser():
    """Build the command-line parser for staging and fetching model artifacts.

    Options:
        --fetch: download the latest staged model instead of staging a new one.
        --entity: entity to download from ("DEFAULT" selects DEFAULT_ENTITY).
        --from_project / --to_project: source and destination W&B projects.
        --run: optional run to search for the checkpoint artifact.
        --ckpt_alias: alias identifying the checkpoint to stage.
        --staged_model_name: name of the staged model artifact.

    Returns:
        The configured ``argparse.ArgumentParser``; its description is the module docstring.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--fetch",
        action="store_true",
        help=f"If provided, check ENTITY/FROM_PROJECT for an artifact with the provided STAGED_MODEL_NAME and download its latest version to {PROD_STAGING_ROOT}/STAGED_MODEL_NAME.",
    )
    parser.add_argument(
        "--entity",
        type=str,
        default=None,
        help=f"Entity from which to download the checkpoint. Note that checkpoints are always uploaded to the logged-in wandb entity. Pass the value 'DEFAULT' to also download from default entity, which is currently {DEFAULT_ENTITY}.",
    )
    parser.add_argument(
        "--from_project",
        type=str,
        default=DEFAULT_FROM_PROJECT,
        help=f"Project from which to download the checkpoint. Default is {DEFAULT_FROM_PROJECT}",
    )
    parser.add_argument(
        "--to_project",
        type=str,
        default=DEFAULT_TO_PROJECT,
        help=f"Project to which to upload the compiled model. Default is {DEFAULT_TO_PROJECT}.",
    )
    parser.add_argument(
        "--run",
        type=str,
        default=None,
        help=f"Optionally, the name of a run to check for an artifact of type {MODEL_CHECKPOINT_TYPE} that has the provided CKPT_ALIAS. Default is None.",
    )
    parser.add_argument(
        "--ckpt_alias",
        type=str,
        default=BEST_CHECKPOINT_ALIAS,
        # fixed: a space was missing between the two sentences ("staged.The")
        help=f"Alias that identifies which model checkpoint should be staged. The artifact's alias can be set manually or programmatically elsewhere. Default is {BEST_CHECKPOINT_ALIAS!r}.",
    )
    parser.add_argument(
        "--staged_model_name",
        type=str,
        default=DEFAULT_STAGED_MODEL_NAME,
        help=f"Name to give the staged model artifact. Default is {DEFAULT_STAGED_MODEL_NAME!r}.",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(_setup_parser().parse_args())
================================================
FILE: lab08/training/tests/test_memorize_iam.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# Memorization test: train on a single overfit batch and check that the
# training loss recorded in the W&B run summary falls below CRITERION
# within MAX_EPOCHS epochs.
# Usage: test_memorize_iam.sh [MAX_EPOCHS] [CRITERION]
# Run from the lab root; the paths below are relative to it.

# Defaults were chosen to aim for a <5 min test on a commodity GPU,
# including the data download step.
MAX_EPOCHS="${1:-100}"
CRITERION="${2:-1.0}"

FAILURE=false

# Prints 1 if torch can see a CUDA device, else 0; passed straight to --gpus.
GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))')

if ! python ./training/run_experiment.py \
  --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \
  --augment_data false --tf_dropout 0.0 \
  --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \
  --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb; then
  FAILURE=true
fi

# Compare train/loss from the latest W&B run's summary file against the criterion.
if ! python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION"; then
  FAILURE=true
fi

if [ "$FAILURE" = true ]; then
  echo "Memorization test failed at loss criterion $CRITERION"
  exit 1
fi

echo "Memorization test passed at loss criterion $CRITERION"
exit 0
================================================
FILE: lab08/training/tests/test_model_development.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# End-to-end model development test:
#   1. train a tiny model on real data with W&B logging,
#   2. stage its latest checkpoint as a W&B artifact,
#   3. fetch the staged model back,
#   4. clean up the local copy and the artifacts of both runs.

FAILURE=false

# Keep CI runs in a separate W&B project from local runs.
CI="${CI:-false}"
if [ "$CI" = false ]; then
  export WANDB_PROJECT="fsdl-testing-2022"
else
  export WANDB_PROJECT="fsdl-testing-2022-ci"
fi

# Print the id of the most recent W&B run, parsed from the run-<id>.wandb file
# in the latest-run directory. Prints nothing if no such file is found.
latest_wandb_run_id() {
  find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//"
}

echo "training smaller version of real model class on real data"
python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \
  --limit_train_batches 1 --limit_val_batches 1 --limit_test_batches 1 --num_sanity_val_steps 0 \
  --num_workers 1 --wandb || FAILURE=true
TRAIN_RUN=$(latest_wandb_run_id)

echo "staging trained model from run $TRAIN_RUN"
python training/stage_model.py --entity DEFAULT --run "$TRAIN_RUN" --staged_model_name test-dummy --ckpt_alias latest --to_project "$WANDB_PROJECT" --from_project "$WANDB_PROJECT" || FAILURE=true

echo "fetching staged model"
python training/stage_model.py --entity DEFAULT --fetch --from_project "$WANDB_PROJECT" --staged_model_name test-dummy || FAILURE=true
STAGE_RUN=$(latest_wandb_run_id)

if [ "$FAILURE" = true ]; then
  echo "Model development test failed"
  echo "cleaning up local files"
  rm -rf text_recognizer/artifacts/test-dummy
  echo "leaving remote files in place"
  exit 1
fi

# cleanup_artifacts.py is invoked with --all, so never call it unless both run
# ids were found: an empty id must not risk every artifact in the project
# being deleted. Check explicitly rather than relying on the cleanup script to
# fail, because its exit status is not inspected below.
if [ -z "$TRAIN_RUN" ] || [ -z "$STAGE_RUN" ]; then
  echo "Model development test failed: could not determine W&B run ids (train: '$TRAIN_RUN', stage: '$STAGE_RUN')"
  echo "cleaning up local files"
  rm -rf text_recognizer/artifacts/test-dummy
  echo "leaving remote files in place"
  exit 1
fi

echo "cleaning up local and remote files"
rm -rf text_recognizer/artifacts/test-dummy
python training/cleanup_artifacts.py --entity DEFAULT --project "$WANDB_PROJECT" \
  --run_ids "$TRAIN_RUN" "$STAGE_RUN" --all -v

echo "Model development test passed"
exit 0
================================================
FILE: lab08/training/tests/test_run_experiment.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

# Smoke tests for training/run_experiment.py:
#   - a full one-epoch loop with a tiny CNN on FakeImageData,
#   - a fast_dev_run of a shrunken ResnetTransformer on IAMParagraphs.
# Every test runs even if an earlier one fails; the exit code reports any failure.

FAILURE=false

echo "running full loop test with CNN on fake data"
if ! python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1; then
  FAILURE=true
fi

echo "running fast_dev_run test of real model class on real data"
if ! python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \
  --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \
  --fast_dev_run --num_sanity_val_steps 0 \
  --num_workers 1; then
  FAILURE=true
fi

if [ "$FAILURE" = true ]; then
  echo "Test for run_experiment.py failed"
  exit 1
fi

echo "Tests for run_experiment.py passed"
exit 0
================================================
FILE: lab08/training/util.py
================================================
"""Utilities for model development scripts: training and staging."""
import argparse
import importlib
# Package prefixed to args.data_class when importing the data class in setup_data_and_model_from_args.
DATA_CLASS_MODULE = "text_recognizer.data"
# Package prefixed to args.model_class when importing the model class in setup_data_and_model_from_args.
MODEL_CLASS_MODULE = "text_recognizer.models"
def import_class(module_and_class_name: str) -> type:
    """Import class from a module, e.g. 'text_recognizer.models.MLP'.

    Args:
        module_and_class_name: Dotted path. Everything before the last dot names the module
            to import; the final component names the attribute to fetch from it.

    Returns:
        The attribute named by the final path component (normally a class).

    Raises:
        ValueError: if the path has no module component (no dot, or nothing before the
            last dot).
        ModuleNotFoundError: if the module cannot be imported.
        AttributeError: if the module has no attribute with the given name.
    """
    module_name, separator, class_name = module_and_class_name.rpartition(".")
    # Without this check, a bare name like "MLP" fails with an opaque
    # "not enough values to unpack" ValueError; keep the type, improve the message.
    if not separator or not module_name:
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {module_and_class_name!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def setup_data_and_model_from_args(args: argparse.Namespace):
    """Instantiate the data object and model named in ``args``.

    ``args.data_class`` is resolved inside DATA_CLASS_MODULE and ``args.model_class`` inside
    MODEL_CLASS_MODULE. The data object is constructed from ``args``. The model is then
    constructed from that data object's ``config()`` together with ``args``.

    Returns:
        A ``(data, model)`` tuple.
    """
    data_cls = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}")
    model_cls = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}")

    data = data_cls(args)
    data_config = data.config()
    return data, model_cls(data_config=data_config, args=args)
================================================
FILE: overview.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "7yQQTA9IGDt8"
},
"source": [
"", *tokens, ""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by LeCun and Canziani](https://cds.nyu.edu/deep-learning/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing and\n",
"matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanding of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect something like one in ten predictions to be correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how well our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run our `backward`s pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss.backward()` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset were, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need is already out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
" self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
" self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining the `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context of this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameter`s\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built into `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial-grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameter`s,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch registers the `nn.Parameter`s\n",
"of any child `nn.Module`s with the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten-line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (the `fc` layers, e.g. `fc1`, and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine to move electrons around."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, the problem of CPUs,\n",
"which are general-purpose computing devices,\n",
"being too slow is solved by using more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"the redheaded stepchild of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
" url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
" \n",
" def __init__(self, dir, bs=32):\n",
" self.dir = dir\n",
" self.bs = bs\n",
" self.path = self.dir / self.filename\n",
"\n",
" def prepare_data(self):\n",
" if not (self.path).exists():\n",
" content = requests.get(self.url + self.filename).content\n",
" self.path.open(\"wb\").write(content)\n",
"\n",
" def setup(self):\n",
" with gzip.open(self.path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
" x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
" )\n",
" \n",
" self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
" self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
" \n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've defined a model class (like `MLP` or `MNISTLogistic`) and the `MNISTDataModule` class,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab02/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch to iteratively generate those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"That's in contrast to the other common way frameworks are designed:\n",
"providing abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of overriding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging # import some stdlib components to control what's display\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"try: # try using the LinearRegression LightningModule defined above\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n",
"\n",
" model = LinearRegression()\n",
"\n",
" # we'll explain how the Trainer works in a bit\n",
" trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
" trainer.fit(model=model) \n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
" print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n",
"\n",
"finally: # bring back info-level logging\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be overridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"    \"\"\"Wraps a CorrelatedDataset and splits its indices into disjoint train/validation ranges.\"\"\"\n",
"\n",
"    def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
"        super().__init__()  # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
"        # set some constants, like the train/val split\n",
"        self.size = size\n",
"        self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
"        self.train_indices = list(range(math.floor(self.size * train_frac)))\n",
"        # validation starts right after the last training index,\n",
"        # so the two splits are disjoint and together cover the whole dataset\n",
"        self.val_indices = list(range(len(self.train_indices), self.size))\n",
"\n",
"        # under the hood, we've still got a torch Dataset\n",
"        self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once per machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None):  # prepares state that needs to be set for each GPU on each node\n",
"    # fitting and standalone validation (trainer.validate) both need the validation split\n",
"    if stage in (None, \"fit\", \"validate\"):  # other stages: \"test\", \"predict\"\n",
"        self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
"        self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self):  # prepares state that needs to be set once per node\n",
"    pass  # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
"    # use the batch size configured on the datamodule, if any, instead of a hard-coded value\n",
"    return torch.utils.data.DataLoader(self.train_dataset, batch_size=getattr(self, \"batch_size\", 32))\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
"    return torch.utils.data.DataLoader(self.val_dataset, batch_size=getattr(self, \"batch_size\", 32))\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"The triggers, or \"hooks\", are specific points in the training, validation, and testing loops.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and overriding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"outweighing the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Also, implementing metric calculation during training adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab02/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other points.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"# column ii is the generic column rolled down by ii positions, so every column is a shifted copy\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution is said to have a _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we repeatedly apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper on convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (to right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represent (layer IDs at bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that custom here and used `__repr__`. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with interspersed \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work off of\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting the wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like DropOut\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first thing you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"This helps prevent training runs from crashing\n",
"at the end of the first epoch,\n",
"which is otherwise the first time validation loops are triggered\n",
"and is sometimes hours into training,\n",
"by crashing them quickly at the start.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `
", "", " and ", *tokens, " and ", *tokens, "",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing,\n",
"matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanging of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect to see that something around one in ten predictions are correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how good our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run our `backward`s pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset was, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need's not out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
" self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
" self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining the `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context of this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameters`\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built into `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameters`,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch assigns the `nn.Parameters`\n",
"of any child `nn.Module`s to the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten-line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (`fc` and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine to move electrons around."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, the problem of CPUs,\n",
"which are general purpose computing devices,\n",
"being too slow is solved by using more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"often the most neglected part of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
" url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
" \n",
" def __init__(self, dir, bs=32):\n",
" self.dir = dir\n",
" self.bs = bs\n",
" self.path = self.dir / self.filename\n",
"\n",
" def prepare_data(self):\n",
" if not (self.path).exists():\n",
" content = requests.get(self.url + self.filename).content\n",
" self.path.open(\"wb\").write(content)\n",
"\n",
" def setup(self):\n",
" with gzip.open(self.path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
" x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
" )\n",
" \n",
" self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
" self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
" \n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've defined a model class (like `MLP` or `MNISTLogistic`) and the `MNISTDataModule` class,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by unpacking a dictionary of keyword arguments into `Namespace`, e.g. `Namespace(**some_dict)`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab03/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch to iteratively generate those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"That's in contrast to the other common way frameworks are designed:\n",
"providing abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of overriding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging # import some stdlib components to control what's display\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"try: # try using the LinearRegression LightningModule defined above\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n",
"\n",
" model = LinearRegression()\n",
"\n",
" # we'll explain how the Trainer works in a bit\n",
" trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
" trainer.fit(model=model) \n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
" print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n",
"\n",
"finally: # bring back info-level logging\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be overridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
" super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
" # set some constants, like the train/val split\n",
" self.size = size\n",
" self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
" self.train_indices = list(range(math.floor(self.size * train_frac)))\n",
" self.val_indices = list(range(self.train_indices[-1], self.size))\n",
"\n",
" # under the hood, we've still got a torch Dataset\n",
" self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once per machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n",
" if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n",
" self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
" self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self): # prepares state that needs to be set once per node\n",
" pass # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"The triggers here, or \"hooks\", are specific points in the training, validation, and testing loops.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and overriding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"outweighing the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Also, implementing metric calculation during training adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab03/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other points.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation invariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution is said to have a _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we repeatedly apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper into convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (to right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represents (layer IDs at bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that convention here and used `__repr__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with intermittent \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work on,\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting a wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like DropOut\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first thing you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"This helps prevent training runs from crashing\n",
"at the end of the first epoch,\n",
"which is otherwise the first time validation loops are triggered\n",
"and is sometimes hours into training,\n",
"by crashing them quickly at the start.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `<P>`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 03: Transformers and Paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The fundamental reasons why the Transformer is such\n",
"a powerful and popular architecture\n",
"- Core intuitions for the behavior of Transformer architectures\n",
"- How to use a convolutional encoder and a Transformer decoder to recognize\n",
"entire paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 3\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Transformers?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal in building a text recognizer is to take a two-dimensional image\n",
"and convert it into a one-dimensional sequence of characters\n",
"from some alphabet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional neural networks,\n",
"discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n",
"are great at encoding images,\n",
"taking them from their raw pixel values\n",
"to a more semantically meaningful numerical representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But how do we go from that to a sequence of letters?\n",
"And what's especially tricky:\n",
"the number of letters in an image can vary independently of its size.\n",
"A screenshot of this document has a much higher density of letters\n",
"than a close-up photograph of a piece of paper.\n",
"How do we get a _variable-length_ sequence of letters,\n",
"where the length need have nothing to do with the size of the input tensor?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n",
"they were\n",
"[originally introduced](https://arxiv.org/abs/1706.03762)\n",
"for transforming one sequence into another,\n",
"as in machine translation.\n",
"This makes them a natural fit for processing language.\n",
"\n",
"But they have also found success in other domains --\n",
"at the time of this writing, large transformers\n",
"dominate the\n",
"[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n",
"that has become a de facto standard for comparing models\n",
"and are finding\n",
"[application in reinforcement learning](https://arxiv.org/abs/2106.01345)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we will use a Transformer as a key component of our final architecture:\n",
"we will encode our input images with a CNN\n",
"and then read them out into a text sequence with a Transformer.\n",
"\n",
"Before trying out this new model,\n",
"let's first get an understanding of why the Transformer architecture\n",
"has become so popular by walking through its history\n",
"and then get some intuition for how it works\n",
"by looking at some\n",
"[recent work](https://transformer-circuits.pub/)\n",
"on explaining the behavior of both toy models and state-of-the-art language models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kmKqjbvd-Mj3"
},
"source": [
"## Why not convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SRqkUMdM-OxU"
},
"source": [
"In the ancient beforetimes (i.e. 2016),\n",
"the best models for natural language processing were all\n",
"_recurrent_ neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional networks were also occasionally used,\n",
"but they suffered from a serious issue:\n",
"their architectural biases don't fit text.\n",
"\n",
"First, _translation equivariance_ no longer holds.\n",
"The beginning of a piece of text is often quite different from the middle,\n",
"so the absolute position matters.\n",
"\n",
"Second, _locality_ is not as important in language.\n",
"The name of a character that hasn't appeared in thousands of pages\n",
"can become salient when someone asks, \"Whatever happened to\n",
"[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n",
"\n",
"Consider interpreting a piece of text like the Python code below:\n",
"```python\n",
"def do(arg1, arg2, arg3):\n",
" a = arg1 + arg2\n",
" b = arg3[:3]\n",
" c = a * b\n",
" return c\n",
"\n",
"print(do(1, 1, \"ayy lmao\"))\n",
"```\n",
"\n",
"After a `(` we expect a `)`,\n",
"but possibly very long afterwards,\n",
"[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n",
"and similarly we expect a `]` at some point after a `[`.\n",
"\n",
"For translation variance, consider\n",
"that we interpret `*` not by\n",
"comparing it to its neighbors\n",
"but by looking at `a` and `b`.\n",
"We mix knowledge learned through experience\n",
"with new facts learned while reading --\n",
"also known as _in-context learning_.\n",
"\n",
"In a longer text,\n",
"[e.g. the one you are reading now](./lab03_transformers.ipynb),\n",
"the translation variance of text is clearer.\n",
"Every lab notebook begins with the same header,\n",
"setting up the environment,\n",
"but that header never appears elsewhere in the notebook.\n",
"Later positions need to be processed in terms of the previous entries.\n",
"\n",
"Unlike an image, we cannot simply rotate or translate our \"camera\"\n",
"and get a new valid text.\n",
"[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n",
"that can be read without regard to position."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The field of formal language theory,\n",
"which has deep mutual influence with computer science,\n",
"gives one way of explaining the issues with convolutional networks:\n",
"they can only understand languages with _finite contexts_,\n",
"where all the information can be found within a finite window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The immediate solution, drawing from the connections to computer science, is\n",
"[recursion](https://www.google.com/search?q=recursion).\n",
"A network whose output on the final entry of the sequence is a recursive function\n",
"of all the previous entries can build up knowledge\n",
"as it reads the sequence and treat early entries quite differently than it does late ones."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa6cbTlImkEh"
},
"source": [
"In pseudo-code, such a _recurrent neural network_ module might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lKtBoPnglPrW"
},
"source": [
"```python\n",
"def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n",
" next_inputs = input_module(xs[-1])\n",
" next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n",
" return output_module(next_inputs, next_hiddens)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbJPSMnEm516"
},
"source": [
"If you've had formal computer science training,\n",
"then you may be familiar with the power of recursion,\n",
"e.g. the\n",
"[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n",
"that gave its name to the now much better-known\n",
"[startup incubator](https://www.ycombinator.com/).\n",
"\n",
"The particular form of recursion used by\n",
"recurrent neural networks implements a\n",
"[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n",
"\n",
"> If you know a lot of computer science,\n",
"you might be concerned by this connection.\n",
"What about other\n",
"[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n",
"Where are the neural network architectures for differentiable\n",
"[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n",
"Check out Graph Neural Networks,\n",
"[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "63mMTbEBpVuE"
},
"source": [
"Recurrent networks are able to achieve\n",
"[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n",
"\n",
"There are many popular recurrent architectures,\n",
"from the beefy and classic\n",
"[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n",
"to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n",
"([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n",
"all of which have roughly similar capabilities but\n",
"[some of which are easier to train](https://arxiv.org/abs/1611.09913)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwQHVTIslOku"
},
"source": [
"In the same sense that MLPs can model \"any\" feedforward function,\n",
"in principle even basic RNNs\n",
"[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n",
"\n",
"In particular they can model any\n",
"[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n",
"which is a formal way of saying that they can in principle\n",
"do anything a computer is capable of doing.\n",
"\n",
"The question is then..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3J8EoGN3pu7P"
},
"source": [
"## Why aren't we all using RNNs?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDwNWaevpt_3"
},
"source": [
"The guarantees that MLPs can model any function\n",
"or that RNNs can model Turing machines\n",
"provide decent intuition but are not directly practically useful.\n",
"Among other reasons, they don't guarantee learnability --\n",
"that starting from random parameters we can find the parameters\n",
"that implement a given function.\n",
"The\n",
"[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n",
"than would seem from basic theoretical and empirical analysis.\n",
"\n",
"One way of understanding capacity to model language is\n",
"[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n",
"In this model of formal languages,\n",
"Turing machines sit at the top\n",
"([practically speaking](https://arxiv.org/abs/math/0209332)).\n",
"\n",
"With better mathematical models,\n",
"RNNs and LSTMs can be shown to be\n",
"[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n",
"with RNNs looking more like\n",
"[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n",
"and LSTMs coming in\n",
"[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n",
"\n",
"More controversially:\n",
"the Chomsky hierarchy is great for understanding syntax and grammar,\n",
"which makes it great for building parsers\n",
"and working with formal languages,\n",
"but the goal in _natural_ language processing is to understand _natural_ language.\n",
"Most humans' natural language is far from strictly grammatical,\n",
"but that doesn't mean it is nonsense.\n",
"\n",
"And to really \"understand\" language means\n",
"to understand its semantic content, which is fuzzy.\n",
"The most important thing for handling the fuzzy semantic content\n",
"of language is not whether you can recall\n",
"[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n",
"but whether you can model probabilistic relationships between concepts\n",
"in addition to grammar and syntax."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These both leave theoretical room for improvement over current recurrent\n",
"language and sequence models.\n",
"\n",
"But the real cause of the rise of Transformers is that..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dsu1ebvAp-3Z"
},
"source": [
"## Transformers are designed to train fast at scale on contemporary hardware."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4abU5adsPGs"
},
"source": [
"The Transformer architecture has several important features,\n",
"discussed below,\n",
"but one of the most important reasons why it is successful\n",
"is that it can be more easily trained at scale.\n",
"\n",
"This scalability is the focus of the discussion in the paper\n",
"that introduced the architecture,\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"and\n",
"[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n",
"\n",
"The recursion in RNNs is inherently sequential:\n",
"the dependence on the outputs from earlier in the sequence\n",
"means computations within an example cannot be parallelized.\n",
"\n",
"So RNNs must batch across examples to scale,\n",
"but as sequence length grows this hits memory bandwidth limits.\n",
"Serving up large batches quickly with good randomness guarantees\n",
"is also hard to optimize,\n",
"especially in distributed settings.\n",
"\n",
"The Transformer architecture,\n",
"on the other hand,\n",
"can be readily parallelized within a single example sequence,\n",
"in addition to parallelization across batches.\n",
"This can lead to massive performance gains for a fixed scale,\n",
"which means larger, higher capacity models\n",
"can be trained on larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Mzk2haFC_G1"
},
"source": [
"How does the architecture achieve this parallelizability?\n",
"\n",
"Let's start with the architecture diagram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u59eu4snLQfp"
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n",
"\n",
"display.Image(url=base_url + \"/aiayn-figure-1.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ez-XEQ7M0UlR"
},
"source": [
"> To head off a bit of confusion\n",
" in case you've worked with Transformer architectures before:\n",
" the original \"Transformer\" is an encoder/decoder architecture.\n",
" Many LLMs, like GPT models, are decoder-only,\n",
" because this has turned out to scale well,\n",
" and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n",
" it's all text anyways.\n",
" We, however, will be using them across modalities,\n",
" so we need an explicit encoder,\n",
" as above. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ok4ksBi4vp89"
},
"source": [
"First focusing on the encoder (left):\n",
"the encoding at a given position is a function of all of the inputs.\n",
"But it is not a function of the previous _encodings_:\n",
"we produce the encodings \"all at once\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN7C-_OqzHP"
},
"source": [
"The decoder (right) does use previous \"outputs\" as its inputs,\n",
"but those outputs are not the vectors of layer activations\n",
"(aka embeddings)\n",
"that are produced by the network.\n",
"They are instead the processed outputs,\n",
"after a `softmax` and an `argmax`.\n",
"\n",
"We could obtain these outputs by processing the embeddings,\n",
"much like in a recurrent architecture.\n",
"In fact, that is one way that Transformers are run.\n",
"It's what happens in the `.forward` method\n",
"of the model we'll be training for character recognition:\n",
"`ResnetTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L5_2WMmtDnJn"
},
"source": [
"Let's look at that forward method\n",
"and connect it to the diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FR5pk4kEyCGg"
},
"outputs": [],
"source": [
"from text_recognizer.models import ResnetTransformer\n",
"\n",
"\n",
"ResnetTransformer.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-J5UFDoPzPbq"
},
"source": [
"`.encode` happens first -- that's the left side of the diagram.\n",
"\n",
"The encoder can in principle be anything\n",
"that produces a sequence of fixed-length vectors,\n",
"but here it's\n",
"[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n",
"\n",
"Then we start iterating over the sequence\n",
"in the `for` loop.\n",
"\n",
"Focus on the first few lines of code.\n",
"We apply `.decode` (right side of the diagram)\n",
"to the outputs so far.\n",
"\n",
"Once we have a new `output`, we apply `.argmax`\n",
"to turn the logits into a concrete prediction of\n",
"a particular token.\n",
"\n",
"This is added as the last output token\n",
"and then the loop happens again."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTcy8-rV1dHr"
},
"source": [
"Run this way, our model looks very much like a recurrent architecture:\n",
"we call the model on its own outputs\n",
"to generate the next value.\n",
"These types of models are also referred to as\n",
"[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n",
"because we predict (as we do in _regression_)\n",
"the next value based on our own (_auto_) output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But Transformers are designed to be _trained_ more scalably than RNNs,\n",
"not necessarily to _run inference_ more scalably,\n",
"and it's actually not the case that our model's `.forward` is called during training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCxMSAWmEKBt"
},
"source": [
"Let's look at what happens during training\n",
"by checking the `training_step`\n",
"of the `LightningModule`\n",
"we use to train our Transformer models,\n",
"the `TransformerLitModel`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o7q8N7P2w4H"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import TransformerLitModel\n",
"\n",
"TransformerLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VgNNOjvzC4y"
},
"source": [
"Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tz-6NGPR4dUr"
},
"source": [
"Let's look at `.teacher_forward`,\n",
"and in particular its type signature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ILc2oWET4i2Z"
},
"outputs": [],
"source": [
"TransformerLitModel.teacher_forward??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lf32lpgrDb__"
},
"source": [
"This is known as \"teacher forcing\".\n",
"The \"teacher\" signal is \"forcing\"\n",
"the model to behave as though\n",
"it got the answer right.\n",
"\n",
"[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n",
"It's more effective here\n",
"because the right teaching signal\n",
"for our network is the target data,\n",
"which we have access to during training,\n",
"whereas in an RNN the best teaching signal\n",
"would be the target embedding vector,\n",
"which we do not know.\n",
"\n",
"During inference, when we don't have access to the ground truth,\n",
"we revert to the autoregressive `.forward` method."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This \"trick\" allows Transformer architectures to readily scale\n",
"up models to the parameter counts\n",
"[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAjqpJm9uUuU"
},
"source": [
"## Is there more to Transformers than just a training trick?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kWCYXeHv7Qc9"
},
"source": [
"[Very](https://arxiv.org/abs/2005.14165),\n",
"[very](https://arxiv.org/abs/1909.08053),\n",
"[very](https://arxiv.org/abs/2205.01068)\n",
"large Transformer models have powered the most recent wave of exciting results in ML, like\n",
"[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n",
"\n",
"They are also the first machine learning models to have come anywhere close to\n",
"deserving the term _artificial intelligence_ --\n",
"a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is surprising because the models and their training procedure are\n",
"(relatively speaking)\n",
"pretty _simple_,\n",
"even if it doesn't feel that way on first pass."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The basic Transformer architecture is just a bunch of\n",
"dense matrix multiplications and non-linearities --\n",
"it's perhaps simpler than a convolutional architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And advances since the introduction of Transformers in 2017\n",
"have not in the main been made by\n",
"creating more sophisticated model architectures\n",
"but by increasing the scale of the base architecture,\n",
"or if anything making it simpler, as in\n",
"[GPT-type models](https://arxiv.org/abs/2005.14165),\n",
"which drop the encoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V1HQS9ey8GMc"
},
"source": [
"These models are also trained on very simple tasks:\n",
"most LLMs are just trying to predict the next element in the sequence,\n",
"given the previous elements --\n",
"a task simple enough that Claude Shannon,\n",
"father of information theory, was\n",
"[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n",
"\n",
"These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n",
"e.g. by scraping the web."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They are also trained in a simple fashion:\n",
"first-order stochastic optimizers, like SGD or an\n",
"[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n",
"intended for the most basic of optimization problems,\n",
"that scale more readily than the second-order optimizers\n",
"that dominate other areas of optimization."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kz9HPDoy7OAl"
},
"source": [
"This is\n",
"[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n",
"of work in ML:\n",
"simple, even seemingly wasteful,\n",
"architectures that scale well and are robust\n",
"to implementation details\n",
"eventually outstrip more clever but\n",
"also more finicky approaches that are harder to scale.\n",
"This lesson has led some to declare that\n",
"[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n",
"in machine learning, and perhaps even in artificial intelligence."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdN9o2Y771YZ"
},
"source": [
"> That is not to say that because the algorithms are relatively simple,\n",
" training a model at this scale is _easy_ --\n",
" [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n",
" [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n",
" [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n",
" But choosing the simplest algorithm at every step makes solving the scaling problem feasible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baVGf6gKFOvs"
},
"source": [
"The importance of scale is the key lesson from the Transformer architecture,\n",
"far more than any theoretical considerations\n",
"or any of the implementation details.\n",
"\n",
"That said, these large Transformer models are capable of\n",
"impressive behaviors and understanding how they achieve them\n",
"is of intellectual interest.\n",
"Furthermore, like any architecture,\n",
"there are common failure modes,\n",
"of the model and of the modelers who use them,\n",
"that need to be taken into account."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"Below, we'll cover two key intuitions about Transformers:\n",
"Transformers are _residual_, like ResNets,\n",
"and they compose _low rank_ sequence transformations.\n",
"Together, this means they act somewhat like a computer,\n",
"reading from and writing to a \"tape\" or memory\n",
"with a sequence of simple instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"We'll also cover a surprising implementation detail:\n",
"despite being commonly used for sequence modeling,\n",
"by default the architecture is _position insensitive_."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uni0VTCr9lev"
},
"source": [
"### Intuition #1: Transformers are highly residual."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MoBt-JLJz-d"
},
"source": [
"> The discussion of these intuitions summarizes the discussion in\n",
"[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n",
"from\n",
"[Anthropic](https://www.anthropic.com/),\n",
"an AI safety and research company.\n",
"The figures below are from that blog post.\n",
"It is the spiritual successor to the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"covered in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab).\n",
"If you want to truly understand Transformers,\n",
"we highly recommend you check it out,\n",
"including the\n",
"[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UUbNVvM5Ferm"
},
"source": [
"It's easy to see that ResNets are residual --\n",
"it's in the name, after all.\n",
"\n",
"But Transformers are,\n",
"in some sense,\n",
"even more closely tied to residual computation\n",
"than are ResNets:\n",
"ResNets and related architectures include downsampling,\n",
"so there is not a direct path from inputs to outputs.\n",
"\n",
"In Transformers, the exact same shape is maintained\n",
"from the moment tokens are embedded,\n",
"through dozens or hundreds of intermediate layers,\n",
"and until they are \"unembedded\" into class logits.\n",
"The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n",
"\n",
"The residual stream is easy to see with a change of perspective.\n",
"Instead of the usual architecture diagram above,\n",
"which emphasizes the layers acting on the tensors,\n",
"consider this alternative view,\n",
"which emphasizes the tensors as they pass through the layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRMlVguKKW6y"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-residual-view.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9K3N7ilVkB3"
},
"source": [
"For definitions of variables and terms, see the\n",
"[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arvciE-kKd_L"
},
"source": [
"Note that this is a _decoder-only_ Transformer architecture --\n",
"so it should be compared with the right-hand side of the original architecture diagram above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvrRMd_RKp_G"
},
"source": [
"Notice that outputs of the attention blocks \n",
"and of the MLP layers are\n",
"added to their inputs, as in a ResNet.\n",
"These operations are represented as \"Add & Norm\" layers in the classical diagram;\n",
"normalization is ignored here for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8n_iT-FFAbK"
},
"source": [
"This total commitment to residual operations\n",
"means the size of the embeddings\n",
"(referred to as the \"model dimension\" or the \"embedding dimension\",\n",
"here and below `d_model`)\n",
"stays the same throughout the entire network.\n",
"\n",
"That means, for example,\n",
"that the output of each layer can be used as input to the \"unembedding\" layer\n",
"that produces logits.\n",
"We can read out the computations of intermediate layers\n",
"just by passing them through the unembedding layer\n",
"and examining the logit tensor.\n",
"See\n",
"[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n",
"for detailed experiments and interactive notebooks.\n",
"\n",
"In short, we observe a sort of \"progressive refinement\"\n",
"of the next-token prediction\n",
"as the embeddings proceed, depthwise, through the network."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ovh_3YgY9z2h"
},
"source": [
"### Intuition #2 Transformer heads learn low rank transformations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpNmozlnOdPC"
},
"source": [
"In the original paper and in\n",
"most presentations of Transformers,\n",
"the attention layer is written like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PA7me8gNP5LE"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pseudo-typed PyTorch (based loosely on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n",
"that looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oeict_6wGJgD"
},
"source": [
"```python\n",
"def classic_attention(\n",
" Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" K: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" return torch.softmax(Q @ K.T) @ V\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8pewU90DSuOR"
},
"source": [
"This is effectively exactly\n",
"how it is written\n",
"in PyTorch,\n",
"apart from implementation details\n",
"(look for `bmm` for the matrix multiplications and a `softmax` call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrgTpKFvOhwc"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"F._scaled_dot_product_attention??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebDXZ0tlSe7g"
},
"source": [
"But the best way to write an operation so that a computer can execute it quickly\n",
"is not necessarily the best way to write it so that a human can understand it --\n",
"otherwise we'd all be coding in assembly.\n",
"\n",
"And this is a strange way to write it --\n",
"you'll notice that what we normally think of\n",
"as the \"inputs\" to the layer are not shown.\n",
"\n",
"We can instead write out the attention layer\n",
"as a function of the inputs $x$.\n",
"We write it for a single \"attention head\".\n",
"Each attention layer includes a number of heads\n",
"that read and write from the residual stream\n",
"simultaneously and independently.\n",
"We also add the output layer weights $W_O$\n",
"and we get:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LuFNR67tQpsf"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(\\underbrace{x^TW_Q^T}_Q \\underbrace{W_Kx}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVnBjjfOLwxP"
},
"source": [
"or, in pseudo-typed PyTorch:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmpOm-HfGaNz"
},
"source": [
"```python\n",
"def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n",
" key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n",
" key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n",
" # maps queries of residual stream to keys from residual stream, independent of position\n",
"\n",
" value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n",
" output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n",
" value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n",
" # transformation applied to each token, regardless of position\n",
"\n",
" attention_logits = x @ key_query_circuit @ x.T\n",
" attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits)\n",
" # maps positions to positions, often very sparse\n",
"\n",
" value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n",
"\n",
" return attention_map @ value_output # transformed tokens filtered by attention map\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dC0eqxZ6UAGT"
},
"source": [
"Consider the `key_query_circuit`\n",
"and `value_output_circuit`\n",
"matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n",
"\n",
"The key/query dimension, `d_head`,\n",
"is small relative to the model's dimension, `d_model`,\n",
"so $W_{QK}$ and $W_{OV}$ are very low rank,\n",
"[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n",
"that they factorize into two matrices,\n",
"one with a smaller number of rows\n",
"and another with a smaller number of columns.\n",
"That number is called the _rank_.\n",
"\n",
"When computing, these matrices are better represented via their components,\n",
"rather than computed directly,\n",
"which leads to the normal implementation of attention.\n",
"\n",
"In a large language model,\n",
"the ratio of residual stream dimension, `d_model`, to\n",
"the dimension of a single head, `d_head`, is huge, often 100:1.\n",
"That means each query, key, and value computed at a position\n",
"is a fairly simple, low-dimensional feature of the residual stream at that position.\n",
"\n",
"For visual intuition,\n",
"we compare what a matrix whose rank is 1/100th of full rank looks like,\n",
"relative to a full rank matrix of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LUbojJMiW2C"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"\n",
"low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n",
"full_rank = torch.randn(100, 100)\n",
"plt.figure(); plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n",
"plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqBst92-OVka"
},
"source": [
"The pattern in the first matrix is very simple,\n",
"relative to the pattern in the second matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkCGrs9EiVh4"
},
"source": [
"Another feature of low rank transformations is\n",
"that they have a large nullspace or kernel --\n",
"these are directions we can move the input without changing the output.\n",
"\n",
"That means that many changes to the residual stream won't affect the behavior of this head at all."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVz2dQgzhD4p"
},
"source": [
"### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVlzwR03m8mC"
},
"source": [
"The combination of residuality\n",
"(changes are added to the current input)\n",
"and low rank\n",
"(only a small subspace is changed by each head)\n",
"drastically changes the intuition about Transformers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qqjZI2jKe6HH"
},
"source": [
"Rather than being an \"embedding of a token in its context\",\n",
"the residual stream becomes something more like a memory or a scratchpad:\n",
"one layer reads a small bit of information from the stream\n",
"and writes a small bit of information back to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YIBkxlqepjc"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-layer-residual.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RtsKhkLfk00l"
},
"source": [
"The residual stream works like a memory because it is roomy enough\n",
"that these actions need not interfere:\n",
"the subspaces targeted by reads and writes are small relative to the ambient space,\n",
"so they can coexist without overwriting one another.\n",
"\n",
"Additionally, the dimension of each head is still in the 100s in large models,\n",
"and\n",
"[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n",
"in them, so the number of effective degrees of freedom is\n",
"actually larger than the dimension.\n",
"This phenomenon allows high-dimensional tensors to serve as\n",
"[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n",
"There are\n",
"[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n",
"\n",
"Together, this means an early layer can write information to the stream\n",
"that can be used by later layers -- by many of them at once, possibly much later.\n",
"Later layers can learn to edit this information,\n",
"e.g. deleting it,\n",
"if doing so reduces the loss,\n",
"but by default the information is preserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EragIygzJg86"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-stream-read-write.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKIaUZjwkpW7"
},
"source": [
"Lastly, the softmax in the attention has a sparsifying effect,\n",
"and so many attention heads are reading from \n",
"just one token and writing to just one other token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dN6VcJqIMKnB"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-token-to-token.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Repeatedly reading information from an external memory\n",
"and using it to decide which operation to perform\n",
"and where to write the results\n",
"is at the core of the\n",
"[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n",
"For a concrete example, the\n",
"[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n",
"includes a dissection of a form of \"pointer arithmetic\"\n",
"that appears in some models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kLFh7Mvnolr"
},
"source": [
"This point of view seems\n",
"very promising for explaining numerous\n",
"otherwise perhaps counterintuitive features of Transformer models.\n",
"\n",
"- This framework predicts that Transformers will readily copy and paste information,\n",
"which might explain phenomena like\n",
"[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n",
"\n",
"- It also readily explains\n",
"[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n",
"an important component of why Transformers perform well on medium-length texts\n",
"and in few-shot learning.\n",
"\n",
"- Transformers also perform better on reasoning tasks when the text\n",
"[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n",
"is added to their input prompt.\n",
"This is partly due to the fact that that prompt is associated,\n",
"in the dataset, with clearer reasoning,\n",
"and since the models are trained to predict which tokens tend to appear\n",
"after an input, they tend to produce better reasoning with that prompt --\n",
"an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n",
"that act to store intermediate information,\n",
"making for a richer residual stream\n",
"for reading and writing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RyLRzgG-93yB"
},
"source": [
"### Implementation detail: Transformers are position-insensitive by default."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oR6PnrlA_hJ2"
},
"source": [
"In the attention calculation\n",
"each token can query each other token,\n",
"with no regard for order.\n",
"Furthermore, the construction of queries, keys, and values\n",
"is based on the content of the embedding vector,\n",
"which does not automatically include its position.\n",
"\"dog bites man\" and \"man bites dog\" are identical, as in\n",
"[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n",
"\n",
"For most sequences,\n",
"this is unacceptable:\n",
"absolute and relative position matter\n",
"and we cannot use the future to predict the past.\n",
"\n",
"We need to add two pieces to get a Transformer architecture that's usable for next-token prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWHxGJz2-6ZK"
},
"source": [
"First, the simpler piece:\n",
"\"causal\" attention,\n",
"so-named because it ensures that values earlier in the sequence\n",
"are not influenced by later values, which would\n",
"[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0c42xi6URYB4"
},
"source": [
"The most common solution is straightforward:\n",
"we calculate attention between all tokens,\n",
"then throw out non-causal values by \"masking\" them\n",
"(this is before applying the softmax,\n",
"so masking means adding $-\\infty$).\n",
"\n",
"This feels wasteful --\n",
"why are we calculating values we don't need?\n",
"Trying to be smarter would be harder,\n",
"and might rely on operations that aren't as optimized as\n",
"matrix multiplication and addition.\n",
"Furthermore, it's \"only\" twice as many operations,\n",
"so it doesn't even show up in $O$-notation.\n",
"\n",
"A sample attention mask generated by our code base is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NXaWe6pT-9jV"
},
"outputs": [],
"source": [
"from text_recognizer.models import transformer_util\n",
"\n",
"\n",
"attention_mask = transformer_util.generate_square_subsequent_mask(100)\n",
"\n",
"ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n",
"plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n",
"print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This solves our causality problem,\n",
"but we still don't have positional information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZamUE4WIoGS2"
},
"source": [
"The standard technique\n",
"is to add alternating sines and cosines\n",
"of increasing frequency to the embeddings\n",
"(there are\n",
"[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n",
"most notably\n",
"[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n",
"Each position in the sequence is then uniquely identifiable\n",
"from the pattern of these values.\n",
"\n",
"> Furthermore, for the same reason that\n",
" [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n",
" translations, e.g. relative positions, are fairly easy to express as linear transformations\n",
" of sines and cosines."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IDG2uOsaELU0"
},
"source": [
"We superimpose this positional information on our embeddings.\n",
"Note that because the model is residual,\n",
"this position information will be by default preserved\n",
"as it passes through the network,\n",
"so it doesn't need to be repeatedly added."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what this positional encoding looks like in our codebase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Zk62Q-a-1Ax"
},
"outputs": [],
"source": [
"PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n",
"\n",
"pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n",
"\n",
"ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n",
"print(pe[:4, :8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep2ClIWvqDms"
},
"source": [
"When we add the positional information to our embeddings,\n",
"both the embedding information and the positional information\n",
"are approximately preserved,\n",
"as can be visually assessed below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJuFjoCzC0Y4"
},
"outputs": [],
"source": [
"fake_embeddings = torch.randn_like(pe) * 0.5\n",
"\n",
"ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n",
"\n",
"fake_embeddings_with_pe = fake_embeddings + pe\n",
"\n",
"plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHIzBxDkEmH8"
},
"source": [
"A [similar technique](https://arxiv.org/abs/2103.06450)\n",
"is used to also incorporate positional information into the image embeddings,\n",
"which are flattened before being fed to the decoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HC1N85wl8dvn"
},
"source": [
"### Learn more about Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lJwYxkjTk15t"
},
"source": [
"We're only able to give a flavor and an intuition for Transformers here.\n",
"\n",
"To improve your grasp on the nuts and bolts, check out the\n",
"[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n",
"which is surprisingly approachable,\n",
"as far as ML research papers go.\n",
"The\n",
"[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n",
"adds code and commentary to the original paper,\n",
"which makes it even more digestible.\n",
"For something even friendlier, check out the\n",
"[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n",
"by Jay Alammar, which has an accompanying\n",
"[video](https://youtu.be/-QH8fRhqFHM).\n",
"\n",
"Anthropic's work on\n",
"[Transformer Circuits](https://transformer-circuits.pub/),\n",
"summarized above, has some of the best material\n",
"for building theoretical understanding\n",
"and is still being updated with extensions and applications of the framework.\n",
"The\n",
"[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n",
"are a great aid for checking and building your understanding.\n",
"\n",
"But they are fairly math-heavy.\n",
"If you have more of a software engineering background, see\n",
"Transformer Circuits co-author Nelson Elhage's blog post\n",
"[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n",
"\n",
"For a gentler introduction to the intuition for Transformers,\n",
"check out Brandon Rohrer's\n",
"[Transformers From Scratch](https://e2eml.school/transformers.html)\n",
"tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qg7zntJES-aT"
},
"source": [
"An aside:\n",
"the matrix multiplications inside attention dominate\n",
"the big-$O$ runtime of Transformers.\n",
"So trying to make the attention mechanism more efficient, e.g. linear time,\n",
"has generated a lot of research\n",
"(review paper\n",
"[here](https://arxiv.org/abs/2009.06732)).\n",
"Despite drawing a lot of attention, so to speak,\n",
"at the time of writing in mid-2022, these methods\n",
"[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n",
"so it isn't likely to be worth the effort to spend time learning about them\n",
"unless you are a Transformer specialist."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCjXysEJ8g9_"
},
"source": [
"# Using Transformers to read paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsfKWnOvqjva"
},
"source": [
"Our simple convolutional model for text recognition from\n",
"[Lab 02b](https://fsdl.me/lab02b-colab)\n",
"could only handle cleanly-separated characters.\n",
"\n",
"It worked by sliding a LeNet-style CNN\n",
"over the image,\n",
"predicting a character for each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "njLdzBqy-I90"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist_lines = text_recognizer.data.EMNISTLines()\n",
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"\n",
"# for sliding, see the for loop over range(S)\n",
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0N6yDBQq8ns"
},
"source": [
"But unfortunately for us, handwritten text\n",
"doesn't come in neatly-separated characters\n",
"of equal size, so we trained our model on synthetic data\n",
"designed to work with that model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiqUVbj0sxLr"
},
"source": [
"Now that we have a better model,\n",
"we can work with better data:\n",
"paragraphs from the\n",
"[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oizsOAcKs-dD"
},
"source": [
"The cell below uses our `LightningDataModule`\n",
"to download and preprocess this data,\n",
"writing results to disk.\n",
"We can then spin up `DataLoader`s to give us batches.\n",
"\n",
"It can take several minutes to run the first time\n",
"on commodity machines,\n",
"with most time spent extracting the data.\n",
"On subsequent runs,\n",
"the time-consuming operations will not be repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uL9LHbjdsUbm"
},
"outputs": [],
"source": [
"iam_paragraphs = text_recognizer.data.IAMParagraphs()\n",
"\n",
"iam_paragraphs.prepare_data()\n",
"iam_paragraphs.setup()\n",
"xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n",
"\n",
"iam_paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBkFN9bbTm_S"
},
"source": [
"Now that we've got a batch,\n",
"let's take a look at some samples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hqaps8yxtBhU"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import wandb\n",
"\n",
"\n",
"def show(y):\n",
" y = y.detach().cpu() # bring back from accelerator if it's being used\n",
" return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"
", "", " and ", *tokens, " and ", *tokens, ""]
self.end_index = self.inverse_mapping["",
""]
self.end_token = inverse_mapping[""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by LeCun and Canziani](https://cds.nyu.edu/deep-learning/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing\n",
"and matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanding of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect roughly one in ten predictions to be correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how well our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run a `backward` pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss.backward()` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset was, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need is already out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
" self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
" self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining the `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context of this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameters`\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built into `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial-grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameters`,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch assigns the `nn.Parameters`\n",
"of any child `nn.Module`s to the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the source code and documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten-line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (the `fc` layers and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine to move around electrons."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, when CPUs,\n",
"which are general-purpose computing devices,\n",
"are too slow, the solution is to use more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.utils.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"the redheaded stepchild of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
" url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
" \n",
" def __init__(self, dir, bs=32):\n",
" self.dir = dir\n",
" self.bs = bs\n",
" self.path = self.dir / self.filename\n",
"\n",
" def prepare_data(self):\n",
" if not (self.path).exists():\n",
" content = requests.get(self.url + self.filename).content\n",
" self.path.open(\"wb\").write(content)\n",
"\n",
" def setup(self):\n",
" with gzip.open(self.path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
" x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
" )\n",
" \n",
" self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
" self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
" \n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've imported the `MLP` and defined the `MNISTDataModule` class,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab04/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch to iteratively generate those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_gradients()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"That's in contrast to the other common way frameworks are designed:\n",
"providing abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of overriding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging # import some stdlib components to control what's display\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"try: # try using the LinearRegression LightningModule defined above\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n",
"\n",
" model = LinearRegression()\n",
"\n",
" # we'll explain how the Trainer works in a bit\n",
" trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
" trainer.fit(model=model) \n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
" print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n",
"\n",
"finally: # bring back info-level logging\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be overridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"\n",
"    def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
"        super().__init__()  # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
"        # set some constants, like the train/val split\n",
"        self.size = size\n",
"        self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
"        num_train = math.floor(self.size * train_frac)\n",
"        # train gets indices [0, num_train) and val gets [num_train, size):\n",
"        # the splits are disjoint, so no example is shared between them\n",
"        self.train_indices = list(range(num_train))\n",
"        self.val_indices = list(range(num_train, self.size))\n",
"\n",
"        # under the hood, we've still got a torch Dataset\n",
"        self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once per machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n",
" if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n",
" self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
" self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self): # prepares state that needs to be set once per node\n",
" pass # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and overriding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"negating the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Implementing metric calculation during training also adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab04/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other points.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation invariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution is said to have a _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper on convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represents (layer IDs at bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that custom here and used `__repr__`. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with intermittent \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work off of\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting the wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like dropout\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first thing you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"This helps prevent training runs from crashing\n",
"at the end of the first epoch,\n",
"which is otherwise the first time validation loops are triggered\n",
"and is sometimes hours into training,\n",
"by crashing them quickly at the start.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `<P>`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 03: Transformers and Paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The fundamental reasons why the Transformer is such\n",
"a powerful and popular architecture\n",
"- Core intuitions for the behavior of Transformer architectures\n",
"- How to use a convolutional encoder and a Transformer decoder to recognize\n",
"entire paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 3\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Transformers?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal in building a text recognizer is to take a two-dimensional image\n",
"and convert it into a one-dimensional sequence of characters\n",
"from some alphabet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional neural networks,\n",
"discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n",
"are great at encoding images,\n",
"taking them from their raw pixel values\n",
"to a more semantically meaningful numerical representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But how do we go from that to a sequence of letters?\n",
"And what's especially tricky:\n",
"the number of letters in an image is separable from its size.\n",
"A screenshot of this document has a much higher density of letters\n",
"than a close-up photograph of a piece of paper.\n",
"How do we get a _variable-length_ sequence of letters,\n",
"where the length need have nothing to do with the size of the input tensor?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n",
"they were\n",
"[originally introduced](https://arxiv.org/abs/1706.03762)\n",
"for transforming one sequence into another,\n",
"as in machine translation.\n",
"This makes them a natural fit for processing language.\n",
"\n",
"But they have also found success in other domains --\n",
"at the time of this writing, large transformers\n",
"dominate the\n",
"[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n",
"that has become a de facto standard for comparing models\n",
"and are finding\n",
"[application in reinforcement learning](https://arxiv.org/abs/2106.01345)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we will use a Transformer as a key component of our final architecture:\n",
"we will encode our input images with a CNN\n",
"and then read them out into a text sequence with a Transformer.\n",
"\n",
"Before trying out this new model,\n",
"let's first get an understanding of why the Transformer architecture\n",
"has become so popular by walking through its history\n",
"and then get some intuition for how it works\n",
"by looking at some\n",
"[recent work](https://transformer-circuits.pub/)\n",
"on explaining the behavior of both toy models and state-of-the-art language models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kmKqjbvd-Mj3"
},
"source": [
"## Why not convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SRqkUMdM-OxU"
},
"source": [
"In the ancient beforetimes (i.e. 2016),\n",
"the best models for natural language processing were all\n",
"_recurrent_ neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional networks were also occasionally used,\n",
"but they suffered from a serious issue:\n",
"their architectural biases don't fit text.\n",
"\n",
"First, _translation equivariance_ no longer holds.\n",
"The beginning of a piece of text is often quite different from the middle,\n",
"so the absolute position matters.\n",
"\n",
"Second, _locality_ is not as important in language.\n",
"The name of a character that hasn't appeared in thousands of pages\n",
"can become salient when someone asks, \"Whatever happened to\n",
"[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n",
"\n",
"Consider interpreting a piece of text like the Python code below:\n",
"```python\n",
"def do(arg1, arg2, arg3):\n",
" a = arg1 + arg2\n",
" b = arg3[:3]\n",
" c = a * b\n",
" return c\n",
"\n",
"print(do(1, 1, \"ayy lmao\"))\n",
"```\n",
"\n",
"After a `(` we expect a `)`,\n",
"but possibly very long afterwards,\n",
"[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n",
"and similarly we expect a `]` at some point after a `[`.\n",
"\n",
"For translation variance, consider\n",
"that we interpret `*` not by\n",
"comparing it to its neighbors\n",
"but by looking at `a` and `b`.\n",
"We mix knowledge learned through experience\n",
"with new facts learned while reading --\n",
"also known as _in-context learning_.\n",
"\n",
"In a longer text,\n",
"[e.g. the one you are reading now](./lab03_transformers.ipynb),\n",
"the translation variance of text is clearer.\n",
"Every lab notebook begins with the same header,\n",
"setting up the environment,\n",
"but that header never appears elsewhere in the notebook.\n",
"Later positions need to be processed in terms of the previous entries.\n",
"\n",
"Unlike an image, we cannot simply rotate or translate our \"camera\"\n",
"and get a new valid text.\n",
"[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n",
"that can be read without regard to position."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The field of formal language theory,\n",
"which has deep mutual influence with computer science,\n",
"gives one way of explaining the issues with convolutional networks:\n",
"they can only understand languages with _finite contexts_,\n",
"where all the information can be found within a finite window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The immediate solution, drawing from the connections to computer science, is\n",
"[recursion](https://www.google.com/search?q=recursion).\n",
"A network whose output on the final entry of the sequence is a recursive function\n",
"of all the previous entries can build up knowledge\n",
"as it reads the sequence and treat early entries quite differently than it does late ones."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa6cbTlImkEh"
},
"source": [
"In pseudo-code, such a _recurrent neural network_ module might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lKtBoPnglPrW"
},
"source": [
"```python\n",
"def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n",
" next_inputs = input_module(xs[-1])\n",
" next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n",
" return output_module(next_inputs, next_hiddens)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbJPSMnEm516"
},
"source": [
"If you've had formal computer science training,\n",
"then you may be familiar with the power of recursion,\n",
"e.g. the\n",
"[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n",
"that gave its name to the now much better-known\n",
"[startup incubator](https://www.ycombinator.com/).\n",
"\n",
"The particular form of recursion used by\n",
"recurrent neural networks implements a\n",
"[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n",
"\n",
"> If you know a lot of computer science,\n",
"you might be concerned by this connection.\n",
"What about other\n",
"[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n",
"Where are the neural network architectures for differentiable\n",
"[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n",
"Check out Graph Neural Networks,\n",
"[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "63mMTbEBpVuE"
},
"source": [
"Recurrent networks are able to achieve\n",
"[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n",
"\n",
"There are many popular recurrent architectures,\n",
"from the beefy and classic\n",
"[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n",
"to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n",
"([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n",
"all of which have roughly similar capabilities but\n",
"[some of which are easier to train](https://arxiv.org/abs/1611.09913)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwQHVTIslOku"
},
"source": [
"In the same sense that MLPs can model \"any\" feedforward function,\n",
"in principle even basic RNNs\n",
"[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n",
"\n",
"In particular they can model any\n",
"[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n",
"which is a formal way of saying that they can in principle\n",
"do anything a computer is capable of doing.\n",
"\n",
"The question is then..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3J8EoGN3pu7P"
},
"source": [
"## Why aren't we all using RNNs?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDwNWaevpt_3"
},
"source": [
"The guarantees that MLPs can model any function\n",
"or that RNNs can model Turing machines\n",
"provide decent intuition but are not directly practically useful.\n",
"Among other reasons, they don't guarantee learnability --\n",
"that starting from random parameters we can find the parameters\n",
"that implement a given function.\n",
"The\n",
"[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n",
"than would seem from basic theoretical and empirical analysis.\n",
"\n",
"One way of understanding capacity to model language is\n",
"[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n",
"In this model of formal languages,\n",
"Turing machines sit at the top\n",
"([practically speaking](https://arxiv.org/abs/math/0209332)).\n",
"\n",
"With better mathematical models,\n",
"RNNs and LSTMs can be shown to be\n",
"[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n",
"with RNNs looking more like\n",
"[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n",
"and LSTMs coming in\n",
"[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n",
"\n",
"More controversially:\n",
"the Chomsky hierarchy is great for understanding syntax and grammar,\n",
"which makes it great for building parsers\n",
"and working with formal languages,\n",
"but the goal in _natural_ language processing is to understand _natural_ language.\n",
"Most humans' natural language is far from strictly grammatical,\n",
"but that doesn't mean it is nonsense.\n",
"\n",
"And to really \"understand\" language means\n",
"to understand its semantic content, which is fuzzy.\n",
"The most important thing for handling the fuzzy semantic content\n",
"of language is not whether you can recall\n",
"[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n",
"but whether you can model probabilistic relationships between concepts\n",
"in addition to grammar and syntax."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These both leave theoretical room for improvement over current recurrent\n",
"language and sequence models.\n",
"\n",
"But the real cause of the rise of Transformers is that..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dsu1ebvAp-3Z"
},
"source": [
"## Transformers are designed to train fast at scale on contemporary hardware."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4abU5adsPGs"
},
"source": [
"The Transformer architecture has several important features,\n",
"discussed below,\n",
"but one of the most important reasons why it is successful\n",
"is that it can be more easily trained at scale.\n",
"\n",
"This scalability is the focus of the discussion in the paper\n",
"that introduced the architecture,\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"and\n",
"[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n",
"\n",
"The recursion in RNNs is inherently sequential:\n",
"the dependence on the outputs from earlier in the sequence\n",
"means computations within an example cannot be parallelized.\n",
"\n",
"So RNNs must batch across examples to scale,\n",
"but as sequence length grows this hits memory bandwidth limits.\n",
"Serving up large batches quickly with good randomness guarantees\n",
"is also hard to optimize,\n",
"especially in distributed settings.\n",
"\n",
"The Transformer architecture,\n",
"on the other hand,\n",
"can be readily parallelized within a single example sequence,\n",
"in addition to parallelization across batches.\n",
"This can lead to massive performance gains for a fixed scale,\n",
"which means larger, higher capacity models\n",
"can be trained on larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Mzk2haFC_G1"
},
"source": [
"How does the architecture achieve this parallelizability?\n",
"\n",
"Let's start with the architecture diagram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u59eu4snLQfp"
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n",
"\n",
"display.Image(url=base_url + \"/aiayn-figure-1.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ez-XEQ7M0UlR"
},
"source": [
"> To head off a bit of confusion\n",
" in case you've worked with Transformer architectures before:\n",
" the original \"Transformer\" is an encoder/decoder architecture.\n",
" Many LLMs, like GPT models, are decoder only,\n",
" because this has turned out to scale well,\n",
" and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n",
" it's all text anyway.\n",
" We, however, will be using them across modalities,\n",
" so we need an explicit encoder,\n",
" as above. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ok4ksBi4vp89"
},
"source": [
"First focusing on the encoder (left):\n",
"the encoding at a given position is a function of all of the inputs.\n",
"But it is not a function of the previous _encodings_:\n",
"we produce the encodings \"all at once\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN7C-_OqzHP"
},
"source": [
"The decoder (right) does use previous \"outputs\" as its inputs,\n",
"but those outputs are not the vectors of layer activations\n",
"(aka embeddings)\n",
"that are produced by the network.\n",
"They are instead the processed outputs,\n",
"after a `softmax` and an `argmax`.\n",
"\n",
"We could obtain these outputs by processing the embeddings,\n",
"much like in a recurrent architecture.\n",
"In fact, that is one way that Transformers are run.\n",
"It's what happens in the `.forward` method\n",
"of the model we'll be training for character recognition:\n",
"`ResnetTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L5_2WMmtDnJn"
},
"source": [
"Let's look at that forward method\n",
"and connect it to the diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FR5pk4kEyCGg"
},
"outputs": [],
"source": [
"from text_recognizer.models import ResnetTransformer\n",
"\n",
"\n",
"ResnetTransformer.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-J5UFDoPzPbq"
},
"source": [
"`.encode` happens first -- that's the left side of the diagram.\n",
"\n",
"The encoder can in principle be anything\n",
"that produces a sequence of fixed-length vectors,\n",
"but here it's\n",
"[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n",
"\n",
"Then we start iterating over the sequence\n",
"in the `for` loop.\n",
"\n",
"Focus on the first few lines of code.\n",
"We apply `.decode` (right side of the diagram)\n",
"to the outputs so far.\n",
"\n",
"Once we have a new `output`, we apply `.argmax`\n",
"to turn the logits into a concrete prediction of\n",
"a particular token.\n",
"\n",
"This is added as the last output token\n",
"and then the loop happens again."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTcy8-rV1dHr"
},
"source": [
"Run this way, our model looks very much like a recurrent architecture:\n",
"we call the model on its own outputs\n",
"to generate the next value.\n",
"These types of models are also referred to as\n",
"[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n",
"because we predict (as we do in _regression_)\n",
"the next value based on our own (_auto_) output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But Transformers are designed to be _trained_ more scalably than RNNs,\n",
"not necessarily to _run inference_ more scalably,\n",
"and it's actually not the case that our model's `.forward` is called during training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCxMSAWmEKBt"
},
"source": [
"Let's look at what happens during training\n",
"by checking the `training_step`\n",
"of the `LightningModule`\n",
"we use to train our Transformer models,\n",
"the `TransformerLitModel`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o7q8N7P2w4H"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import TransformerLitModel\n",
"\n",
"TransformerLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VgNNOjvzC4y"
},
"source": [
"Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tz-6NGPR4dUr"
},
"source": [
"Let's look at `.teacher_forward`,\n",
"and in particular its type signature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ILc2oWET4i2Z"
},
"outputs": [],
"source": [
"TransformerLitModel.teacher_forward??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lf32lpgrDb__"
},
"source": [
"This is known as \"teacher forcing\".\n",
"The \"teacher\" signal is \"forcing\"\n",
"the model to behave as though\n",
"it got the answer right.\n",
"\n",
"[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n",
"It's more effective here\n",
"because the right teaching signal\n",
"for our network is the target data,\n",
"which we have access to during training,\n",
"whereas in an RNN the best teaching signal\n",
"would be the target embedding vector,\n",
"which we do not know.\n",
"\n",
"During inference, when we don't have access to the ground truth,\n",
"we revert to the autoregressive `.forward` method."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This \"trick\" allows Transformer architectures to readily scale\n",
"up models to the parameter counts\n",
"[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAjqpJm9uUuU"
},
"source": [
"## Is there more to Transformers than just a training trick?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kWCYXeHv7Qc9"
},
"source": [
"[Very](https://arxiv.org/abs/2005.14165),\n",
"[very](https://arxiv.org/abs/1909.08053),\n",
"[very](https://arxiv.org/abs/2205.01068)\n",
"large Transformer models have powered the most recent wave of exciting results in ML, like\n",
"[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n",
"\n",
"They are also the first machine learning models to have come anywhere close to\n",
"deserving the term _artificial intelligence_ --\n",
"a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is surprising because the models and their training procedure are\n",
"(relatively speaking)\n",
"pretty _simple_,\n",
"even if it doesn't feel that way on first pass."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The basic Transformer architecture is just a bunch of\n",
"dense matrix multiplications and non-linearities --\n",
"it's perhaps simpler than a convolutional architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And advances since the introduction of Transformers in 2017\n",
"have not in the main been made by\n",
"creating more sophisticated model architectures\n",
"but by increasing the scale of the base architecture,\n",
"or if anything making it simpler, as in\n",
"[GPT-type models](https://arxiv.org/abs/2005.14165),\n",
"which drop the encoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V1HQS9ey8GMc"
},
"source": [
"These models are also trained on very simple tasks:\n",
"most LLMs are just trying to predict the next element in the sequence,\n",
"given the previous elements --\n",
"a task simple enough that Claude Shannon,\n",
"father of information theory, was\n",
"[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n",
"\n",
"These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n",
"e.g. by scraping the web."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They are also trained in a simple fashion:\n",
"first-order stochastic optimizers, like SGD or an\n",
"[Adam variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n",
"intended for the most basic of optimization problems,\n",
"that scale more readily than the second-order optimizers\n",
"that dominate other areas of optimization."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kz9HPDoy7OAl"
},
"source": [
"This is\n",
"[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n",
"of work in ML:\n",
"simple, even seemingly wasteful,\n",
"architectures that scale well and are robust\n",
"to implementation details\n",
"eventually outstrip more clever but\n",
"also more finicky approaches that are harder to scale.\n",
"This lesson has led some to declare that\n",
"[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n",
"in machine learning, and perhaps even in artificial intelligence."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdN9o2Y771YZ"
},
"source": [
"> That is not to say that because the algorithms are relatively simple,\n",
" training a model at this scale is _easy_ --\n",
" [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n",
" [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n",
" [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n",
" But choosing the simplest algorithm at every step makes solving the scaling problem feasible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baVGf6gKFOvs"
},
"source": [
"The importance of scale is the key lesson from the Transformer architecture,\n",
"far more than any theoretical considerations\n",
"or any of the implementation details.\n",
"\n",
"That said, these large Transformer models are capable of\n",
"impressive behaviors and understanding how they achieve them\n",
"is of intellectual interest.\n",
"Furthermore, like any architecture,\n",
"there are common failure modes,\n",
"of the model and of the modelers who use them,\n",
"that need to be taken into account."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"Below, we'll cover two key intuitions about Transformers:\n",
"Transformers are _residual_, like ResNets,\n",
"and they compose _low rank_ sequence transformations.\n",
"Together, this means they act somewhat like a computer,\n",
"reading from and writing to a \"tape\" or memory\n",
"with a sequence of simple instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67R"
},
"source": [
"We'll also cover a surprising implementation detail:\n",
"despite being commonly used for sequence modeling,\n",
"by default the architecture is _position insensitive_."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uni0VTCr9lev"
},
"source": [
"### Intuition #1: Transformers are highly residual."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MoBt-JLJz-d"
},
"source": [
"> The discussion of these intuitions summarizes the discussion in\n",
"[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n",
"from\n",
"[Anthropic](https://www.anthropic.com/),\n",
"an AI safety and research company.\n",
"The figures below are from that blog post.\n",
"It is the spiritual successor to the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"covered in\n",
"[Lab 02b](https://lab02b-colab).\n",
"If you want to truly understand Transformers,\n",
"we highly recommend you check it out,\n",
"including the\n",
"[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UUbNVvM5Ferm"
},
"source": [
"It's easy to see that ResNets are residual --\n",
"it's in the name, after all.\n",
"\n",
"But Transformers are,\n",
"in some sense,\n",
"even more closely tied to residual computation\n",
"than are ResNets:\n",
"ResNets and related architectures include downsampling,\n",
"so there is not a direct path from inputs to outputs.\n",
"\n",
"In Transformers, the exact same shape is maintained\n",
"from the moment tokens are embedded,\n",
"through dozens or hundreds of intermediate layers,\n",
"and until they are \"unembedded\" into class logits.\n",
"The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n",
"\n",
"The residual stream is easy to see with a change of perspective.\n",
"Instead of the usual architecture diagram above,\n",
"which emphasizes the layers acting on the tensors,\n",
"consider this alternative view,\n",
"which emphasizes the tensors as they pass through the layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRMlVguKKW6y"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-residual-view.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9K3N7ilVkB3"
},
"source": [
"For definitions of variables and terms, see the\n",
"[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arvciE-kKd_L"
},
"source": [
"Note that this is a _decoder-only_ Transformer architecture --\n",
"so it should be compared with the right-hand side of the original architecture diagram above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvrRMd_RKp_G"
},
"source": [
"Notice that outputs of the attention blocks \n",
"and of the MLP layers are\n",
"added to their inputs, as in a ResNet.\n",
"These operations are represented as \"Add & Norm\" layers in the classical diagram;\n",
"normalization is ignored here for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8n_iT-FFAbK"
},
"source": [
"This total commitment to residual operations\n",
"means the size of the embeddings\n",
"(referred to as the \"model dimension\" or the \"embedding dimension\",\n",
"here and below `d_model`)\n",
"stays the same throughout the entire network.\n",
"\n",
"That means, for example,\n",
"that the output of each layer can be used as input to the \"unembedding\" layer\n",
"that produces logits.\n",
"We can read out the computations of intermediate layers\n",
"just by passing them through the unembedding layer\n",
"and examining the logit tensor.\n",
"See\n",
"[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n",
"for detailed experiments and interactive notebooks.\n",
"\n",
"In short, we observe a sort of \"progressive refinement\"\n",
"of the next-token prediction\n",
"as the embeddings proceed, depthwise, through the network."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ovh_3YgY9z2h"
},
"source": [
"### Intuition #2: Transformer heads learn low-rank transformations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpNmozlnOdPC"
},
"source": [
"In the original paper and in\n",
"most presentations of Transformers,\n",
"the attention layer is written like so\n",
"(omitting the $1/\\sqrt{d_k}$ scaling of the logits, for simplicity):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PA7me8gNP5LE"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pseudo-typed PyTorch (based loosely on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n",
"that looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oeict_6wGJgD"
},
"source": [
"```python\n",
"def classic_attention(\n",
" Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" K: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" return torch.softmax(Q @ K.T, dim=-1) @ V\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8pewU90DSuOR"
},
"source": [
"This is effectively exactly\n",
"how it is written\n",
"in PyTorch,\n",
"apart from implementation details\n",
"(look for `bmm` for the matrix multiplications and a `softmax` call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrgTpKFvOhwc"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"F._scaled_dot_product_attention??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebDXZ0tlSe7g"
},
"source": [
"But the best way to write an operation so that a computer can execute it quickly\n",
"is not necessarily the best way to write it so that a human can understand it --\n",
"otherwise we'd all be coding in assembly.\n",
"\n",
"And this is a strange way to write it --\n",
"you'll notice that what we normally think of\n",
"as the \"inputs\" to the layer are not shown.\n",
"\n",
"We can instead write out the attention layer\n",
"as a function of the inputs $x$.\n",
"We write it for a single \"attention head\".\n",
"Each attention layer includes a number of heads\n",
"that read and write from the residual stream\n",
"simultaneously and independently.\n",
"We also add the output layer weights $W_O$\n",
"and we get:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LuFNR67tQpsf"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(\\underbrace{x W_Q^T}_Q \\underbrace{W_K x^T}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVnBjjfOLwxP"
},
"source": [
"or, in pseudo-typed PyTorch:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmpOm-HfGaNz"
},
"source": [
"```python\n",
"def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n",
" key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n",
" key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n",
" # maps queries of residual stream to keys from residual stream, independent of position\n",
"\n",
" value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n",
" output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n",
" value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n",
" # transformation applied to each token, regardless of position\n",
"\n",
" attention_logits: torch.Tensor[\"d_sequence\", \"d_sequence\"] = x @ key_query_circuit @ x.T\n",
" attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n",
" # maps positions to positions, often very sparse\n",
"\n",
" value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n",
"\n",
" return attention_map @ value_output # transformed tokens filtered by attention map\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dC0eqxZ6UAGT"
},
"source": [
"Consider the `key_query_circuit`\n",
"and `value_output_circuit`\n",
"matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n",
"\n",
"The key/query dimension, `d_head`\n",
"is small relative to the model's dimension, `d_model`,\n",
"so $W_{QK}$ and $W_{OV}$ are very low rank,\n",
"[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n",
"that they factorize into two matrices,\n",
"one with a smaller number of rows\n",
"and another with a smaller number of columns.\n",
"That number is called the _rank_.\n",
"\n",
"When computing, these matrices are better represented via their components,\n",
"rather than computed directly,\n",
"which leads to the normal implementation of attention.\n",
"\n",
"In a large language model,\n",
"the ratio of residual stream dimension, `d_model`, to\n",
"the dimension of a single head, `d_head`, is huge, often 100:1.\n",
"That means each query, key, and value computed at a position\n",
"is a fairly simple, low-dimensional feature of the residual stream at that position.\n",
"\n",
"For visual intuition,\n",
"we compare what a matrix whose rank is 1/100th of full rank looks like,\n",
"relative to a full rank matrix of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LUbojJMiW2C"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"\n",
"low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n",
"full_rank = torch.randn(100, 100)\n",
"plt.figure(); plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n",
"plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqBst92-OVka"
},
"source": [
"The pattern in the first matrix is very simple,\n",
"relative to the pattern in the second matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkCGrs9EiVh4"
},
"source": [
"Another feature of low rank transformations is\n",
"that they have a large nullspace or kernel --\n",
"these are directions we can move the input without changing the output.\n",
"\n",
"That means that many changes to the residual stream won't affect the behavior of this head at all."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVz2dQgzhD4p"
},
"source": [
"### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVlzwR03m8mC"
},
"source": [
"The combination of residuality\n",
"(changes are added to the current input)\n",
"and low rank\n",
"(only a small subspace is changed by each head)\n",
"drastically changes the intuition about Transformers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qqjZI2jKe6HH"
},
"source": [
"Rather than being an \"embedding of a token in its context\",\n",
"the residual stream becomes something more like a memory or a scratchpad:\n",
"one layer reads a small bit of information from the stream\n",
"and writes a small bit of information back to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YIBkxlqepjc"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-layer-residual.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RtsKhkLfk00l"
},
"source": [
"The residual stream works like a memory because it is roomy enough\n",
"that these actions need not interfere:\n",
"the subspaces targeted by reads and writes are small relative to the ambient space,\n",
"so they can coexist without overwriting one another.\n",
"\n",
"Additionally, the dimension of each head is still in the 100s in large models,\n",
"and\n",
"[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n",
"in them, so the effective number of degrees of freedom is\n",
"actually larger than the dimension.\n",
"This phenomenon allows high-dimensional tensors to serve as\n",
"[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n",
"There are\n",
"[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n",
"\n",
"Together, this means an early layer can write information to the stream\n",
"that can be used by later layers -- by many of them at once, possibly much later.\n",
"Later layers can learn to edit this information,\n",
"e.g. deleting it,\n",
"if doing so reduces the loss,\n",
"but by default the information is preserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EragIygzJg86"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-stream-read-write.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKIaUZjwkpW7"
},
"source": [
"Lastly, the softmax in the attention has a sparsifying effect,\n",
"and so many attention heads are reading from \n",
"just one token and writing to just one other token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dN6VcJqIMKnB"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-token-to-token.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Repeatedly reading information from an external memory\n",
"and using it to decide which operation to perform\n",
"and where to write the results\n",
"is at the core of the\n",
"[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n",
"For a concrete example, the\n",
"[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n",
"includes a dissection of a form of \"pointer arithmetic\"\n",
"that appears in some models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kLFh7Mvnolr"
},
"source": [
"This point of view seems\n",
"very promising for explaining numerous\n",
"otherwise perhaps counterintuitive features of Transformer models.\n",
"\n",
"- This framework predicts that Transformers will readily copy and paste information,\n",
"which might explain phenomena like\n",
"[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n",
"\n",
"- It also readily explains\n",
"[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n",
"an important component of why Transformers perform well on medium-length texts\n",
"and in few-shot learning.\n",
"\n",
"- Transformers also perform better on reasoning tasks when the text\n",
"[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n",
"is added to their input prompt.\n",
"This is partly due to the fact that that prompt is associated,\n",
"in the dataset, with clearer reasoning,\n",
"and since the models are trained to predict which tokens tend to appear\n",
"after an input, they tend to produce better reasoning with that prompt --\n",
"an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n",
"that act to store intermediate information,\n",
"making for a richer residual stream\n",
"for reading and writing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RyLRzgG-93yB"
},
"source": [
"### Implementation detail: Transformers are position-insensitive by default."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oR6PnrlA_hJ2"
},
"source": [
"In the attention calculation\n",
"each token can query each other token,\n",
"with no regard for order.\n",
"Furthermore, the construction of queries, keys, and values\n",
"is based on the content of the embedding vector,\n",
"which does not automatically include its position.\n",
"\"dog bites man\" and \"man bites dog\" are identical, as in\n",
"[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n",
"\n",
"For most sequences,\n",
"this is unacceptable:\n",
"absolute and relative position matter\n",
"and we cannot use the future to predict the past.\n",
"\n",
"We need to add two pieces to get a Transformer architecture that's usable for next-token prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWHxGJz2-6ZK"
},
"source": [
"First, the simpler piece:\n",
"\"causal\" attention,\n",
"so-named because it ensures that values earlier in the sequence\n",
"are not influenced by later values, which would\n",
"[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0c42xi6URYB4"
},
"source": [
"The most common solution is straightforward:\n",
"we calculate attention between all tokens,\n",
"then throw out non-causal values by \"masking\" them\n",
"(this is before applying the softmax,\n",
"so masking means adding $-\\infty$).\n",
"\n",
"This feels wasteful --\n",
"why are we calculating values we don't need?\n",
"Trying to be smarter would be harder,\n",
"and might rely on operations that aren't as optimized as\n",
"matrix multiplication and addition.\n",
"Furthermore, it's \"only\" twice as many operations,\n",
"so it doesn't even show up in $O$-notation.\n",
"\n",
"A sample attention mask generated by our code base is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NXaWe6pT-9jV"
},
"outputs": [],
"source": [
"from text_recognizer.models import transformer_util\n",
"\n",
"\n",
"attention_mask = transformer_util.generate_square_subsequent_mask(100)\n",
"\n",
"ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n",
"plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n",
"print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This solves our causality problem,\n",
"but we still don't have positional information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZamUE4WIoGS2"
},
"source": [
"The standard technique\n",
"is to add alternating sines and cosines\n",
"of increasing frequency to the embeddings\n",
"(there are\n",
"[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n",
"most notably\n",
"[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n",
"Each position in the sequence is then uniquely identifiable\n",
"from the pattern of these values.\n",
"\n",
"> Furthermore, for the same reason that\n",
" [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n",
" translations, e.g. relative positions, are fairly easy to express as linear transformations\n",
" of sines and cosines."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IDG2uOsaELU0"
},
"source": [
"We superimpose this positional information on our embeddings.\n",
"Note that because the model is residual,\n",
"this position information will be preserved by default\n",
"as it passes through the network,\n",
"so it doesn't need to be repeatedly added."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what this positional encoding looks like in our codebase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Zk62Q-a-1Ax"
},
"outputs": [],
"source": [
"PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n",
"\n",
"pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n",
"\n",
"ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n",
"print(pe[:4, :8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep2ClIWvqDms"
},
"source": [
"When we add the positional information to our embeddings,\n",
"both the embedding information and the positional information\n",
"are approximately preserved,\n",
"as can be visually assessed below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJuFjoCzC0Y4"
},
"outputs": [],
"source": [
"fake_embeddings = torch.randn_like(pe) * 0.5\n",
"\n",
"ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n",
"\n",
"fake_embeddings_with_pe = fake_embeddings + pe\n",
"\n",
"plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHIzBxDkEmH8"
},
"source": [
"A [similar technique](https://arxiv.org/abs/2103.06450)\n",
"is used to also incorporate positional information into the image embeddings,\n",
"which are flattened before being fed to the decoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HC1N85wl8dvn"
},
"source": [
"### Learn more about Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lJwYxkjTk15t"
},
"source": [
"We're only able to give a flavor and an intuition for Transformers here.\n",
"\n",
"To improve your grasp on the nuts and bolts, check out the\n",
"[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n",
"which is surprisingly approachable,\n",
"as far as ML research papers go.\n",
"The\n",
"[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n",
"adds code and commentary to the original paper,\n",
"which makes it even more digestible.\n",
"For something even friendlier, check out the\n",
"[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n",
"by Jay Alammar, which has an accompanying\n",
"[video](https://youtu.be/-QH8fRhqFHM).\n",
"\n",
"Anthropic's work on\n",
"[Transformer Circuits](https://transformer-circuits.pub/),\n",
"summarized above, has some of the best material\n",
"for building theoretical understanding\n",
"and is still being updated with extensions and applications of the framework.\n",
"The\n",
"[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n",
"are a great aid for checking and building your understanding.\n",
"\n",
"But they are fairly math-heavy.\n",
"If you have more of a software engineering background, see\n",
"Transformer Circuits co-author Nelson Elhage's blog post\n",
"[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n",
"\n",
"For a gentler introduction to the intuition for Transformers,\n",
"check out Brandon Rohrer's\n",
"[Transformers From Scratch](https://e2eml.school/transformers.html)\n",
"tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qg7zntJES-aT"
},
"source": [
"An aside:\n",
"the matrix multiplications inside attention dominate\n",
"the big-$O$ runtime of Transformers.\n",
"So trying to make the attention mechanism more efficient, e.g. linear time,\n",
"has generated a lot of research\n",
"(review paper\n",
"[here](https://arxiv.org/abs/2009.06732)).\n",
"Despite drawing a lot of attention, so to speak,\n",
"at the time of writing in mid-2022, these methods\n",
"[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n",
"so it isn't likely to be worth the effort to spend time learning about them\n",
"unless you are a Transformer specialist."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCjXysEJ8g9_"
},
"source": [
"# Using Transformers to read paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsfKWnOvqjva"
},
"source": [
"Our simple convolutional model for text recognition from\n",
"[Lab 02b](https://fsdl.me/lab02b-colab)\n",
"could only handle cleanly-separated characters.\n",
"\n",
"It worked by sliding a LeNet-style CNN\n",
"over the image,\n",
"predicting a character for each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "njLdzBqy-I90"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist_lines = text_recognizer.data.EMNISTLines()\n",
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"\n",
"# for sliding, see the for loop over range(S)\n",
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0N6yDBQq8ns"
},
"source": [
"But unfortunately for us, handwritten text\n",
"doesn't come in neatly-separated characters\n",
"of equal size, so we trained our model on synthetic data\n",
"designed to work with that model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiqUVbj0sxLr"
},
"source": [
"Now that we have a better model,\n",
"we can work with better data:\n",
"paragraphs from the\n",
"[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oizsOAcKs-dD"
},
"source": [
"The cell below uses our `LightningDataModule`\n",
"to download and preprocess this data,\n",
"writing results to disk.\n",
"We can then spin up `DataLoader`s to give us batches.\n",
"\n",
"It can take several minutes to run the first time\n",
"on commodity machines,\n",
"with most time spent extracting the data.\n",
"On subsequent runs,\n",
"the time-consuming operations will not be repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uL9LHbjdsUbm"
},
"outputs": [],
"source": [
"iam_paragraphs = text_recognizer.data.IAMParagraphs()\n",
"\n",
"iam_paragraphs.prepare_data()\n",
"iam_paragraphs.setup()\n",
"xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n",
"\n",
"iam_paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBkFN9bbTm_S"
},
"source": [
"Now that we've got a batch,\n",
"let's take a look at some samples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hqaps8yxtBhU"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import wandb\n",
"\n",
"\n",
"def show(y):\n",
" y = y.detach().cpu() # bring back from accelerator if it's being used\n",
" return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 04: Experiment Management"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How experiment management brings observability to ML model development\n",
"- Which features of experiment management we use in developing the Text Recognizer\n",
"- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 4\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lab contains a large number of embedded iframes\n",
"that benefit from having a wide window.\n",
"The cell below makes the notebook as wide as your browser window\n",
"if `full_width` is set to `True`.\n",
"Full width is the default behavior in Colab,\n",
"so this cell is intended to improve the viewing experience in other Jupyter environments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Why experiment management?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To understand why we need experiment management for ML development,\n",
"let's start by running an experiment.\n",
"\n",
"We'll train a new model on a new dataset,\n",
"using the training script `training/run_experiment.py`\n",
"introduced in [Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use a CNN encoder and Transformer decoder, as in\n",
"[Lab 03](https://fsdl.me/lab03-colab),\n",
"but with some changes so we can iterate faster.\n",
"We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab),\n",
"and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n",
"from text_recognizer.data import IAMLines # processed version split into individual lines\n",
"from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n",
"\n",
"\n",
"print(IAM.__doc__)\n",
"\n",
"# uncomment a line below for details on either class\n",
"# IAMLines?? \n",
"# LineCNNTransformer??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below will train a model on 10% of the data for two epochs.\n",
"\n",
"It takes up to a few minutes to run on commodity hardware,\n",
"including data download and preprocessing.\n",
"As it's running, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"import torch\n",
"\n",
"\n",
"gpus = int(torch.cuda.is_available()) \n",
"\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the model trains, we're calculating lots of metrics --\n",
"loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n",
"and reporting them to the terminal.\n",
"\n",
"This is achieved by the built-in `.log` method\n",
"([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n",
"of the `LightningModule`,\n",
"and it is a very straightforward way to get basic information about your experiment as it's running\n",
"without leaving the context where you're running it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Learning to read\n",
"[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n",
"is something of a rite of passage for MLEs, but\n",
"let's consider what we can't see here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We're missing all metric values except the most recent --\n",
"we can see them as they stream in, but they're constantly overwritten.\n",
"We also can't associate them with timestamps, steps, or epochs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We also don't see any system metrics.\n",
"We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n",
"without launching a separate process.\n",
"And even if we do, those values will also not be saved and timestamped,\n",
"so we can't correlate them with other things during training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As we continue to run experiments, changing code and opening new terminals,\n",
"even the information we have or could figure out now will disappear.\n",
"Say you spot a weird error message during training,\n",
"but your session ends and the stdout is gone,\n",
"so you don't know exactly what it was.\n",
"Can you recreate the error?\n",
"Which git branch and commit were you on?\n",
"Did you have any uncommitted changes? Which arguments did you pass?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Also, model checkpoints containing the parameter values have been saved to disk.\n",
"Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n",
"As we run more and more experiments,\n",
"we'll want to slice and dice them to see if,\n",
"say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to save and log all of this information, and more, in order to make our model training\n",
"[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n",
"in short, so that we can understand, make decisions about, and debug our model training\n",
"by looking at logs and source code, without having to recreate it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we had to write the code to save and log all of this information ourselves, we'd be in for a world of hurt:\n",
"1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n",
"2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n",
"3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n",
"4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local Experiment Tracking with TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n",
"and it makes logging very easy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n",
"\n",
"It can also use a logger to push information elsewhere.\n",
"\n",
"By default, we use\n",
"[TensorBoard](https://www.tensorflow.org/tensorboard)\n",
"via the Lightning `TensorBoardLogger`,\n",
"which has been saving results to the local disk.\n",
"\n",
"Let's find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest experiment's directory\n",
"# by hand, you can just copy and paste it from the terminal\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n",
"filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n",
"latest_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!ls -lh {latest_log}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, we need to launch a TensorBoard server --\n",
"much like we need to launch a Jupyter server to use Jupyter notebooks.\n",
"\n",
"The cells below load an extension that lets you use TensorBoard inside of a notebook\n",
"the same way you'd use it from the command line, and then launch it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n",
"\n",
"port = 11717 # pick an open port on your machine\n",
"host = \"0.0.0.0\" # allow connections from the internet\n",
" # watch out! make sure you turn TensorBoard off\n",
"\n",
"%tensorboard --logdir {latest_log} --port {port} --host {host}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see some charts of metrics over time along with some charting controls.\n",
"\n",
"You can click around in this interface and explore it if you'd like,\n",
"but in the next section, we'll see that there are better tools for experiment management."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've run many experiments on this machine,\n",
"you can see all of their results by pointing TensorBoard\n",
"at the whole `lightning_logs` directory,\n",
"rather than just one experiment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For large numbers of experiments, the management experience is not great --\n",
"for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n",
"\n",
"It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n",
"which are important workflows as applications mature and teams grow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n",
"\n",
"If you run into any issues with the above cells failing to launch,\n",
"especially across iterations of this lab, run this cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorboard.manager\n",
"\n",
"# get the process IDs for all tensorboard instances\n",
"pids = [tb.pid for tb in tensorboard.manager.get_all()]\n",
"\n",
"done_with_tensorboard = False\n",
"\n",
"if done_with_tensorboard:\n",
" # kill processes\n",
" for pid in pids:\n",
" !kill {pid} 2> /dev/null\n",
" \n",
" # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n",
" !rm -rf {tensorboard.manager._get_info_dir()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiment Management with Weights & Biases"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How do we manage experiments when we hit the limits of local TensorBoard?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is powerful and flexible and very scalable,\n",
"but running it requires engineering effort and babysitting --\n",
"you're running a database, writing data to it,\n",
"and layering a web application over it.\n",
"\n",
"This is a fairly common workflow for web developers,\n",
"but not so much for ML engineers.\n",
"\n",
"You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n",
"and it's as simple as running the command `tensorboard dev upload`\n",
"pointed at your logging directory.\n",
"\n",
"But there are strict limits to this free service:\n",
"1GB of tensor data and 1GB of binary data.\n",
"A single Text Recognizer model checkpoint is ~100MB,\n",
"and that's not particularly large for a useful model.\n",
"\n",
"Furthermore, all data is public,\n",
"so if you upload the inputs and outputs of your model,\n",
"anyone who finds the link can see them.\n",
"\n",
"Overall, tensorboard.dev works very well for certain academic and open projects\n",
"but not for industrial ML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To avoid those permission and storage limits,\n",
"you could use [git LFS](https://git-lfs.github.com/)\n",
"to track the binary data and tensor data,\n",
"which is more likely to be sensitive than metrics.\n",
"\n",
"The Hugging Face ecosystem uses TensorBoard and git LFS.\n",
"\n",
"It includes the Hugging Face Hub, a git server much like GitHub,\n",
"but designed first and foremost for collaboration on models and datasets,\n",
"rather than collaboration on code.\n",
"For example, the Hugging Face Hub\n",
"[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n",
"and officially has\n",
"[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n",
"avoiding the\n",
"[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n",
"that make using git LFS with GitHub expensive.\n",
"\n",
"However, we prefer to avoid mixing software version control and experiment management.\n",
"\n",
"First, using the Hub requires maintaining an additional git remote,\n",
"which is a hard ask for many engineering teams.\n",
"\n",
"Secondly, git-style versioning is an awkward fit for logging --\n",
"is it really sensible to create a new commit for each logging event while you're watching live?\n",
"\n",
"Instead, we prefer to use systems that solve experiment management with _databases_."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n",
"The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n",
"tool is [MLflow](https://github.com/mlflow/mlflow/)\n",
"and there are a number of\n",
"[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n",
"\n",
"These tools generally avoid any need to worry about hosting\n",
"(unless data governance rules require a self-hosted version).\n",
"\n",
"For a sampling of publicly-posted opinions on experiment management tools,\n",
"see these discussions from Reddit:\n",
"\n",
"- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n",
"- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n",
"\n",
"Among these tools, the FSDL recommendation is\n",
"[Weights & Biases](https://wandb.ai),\n",
"which we believe offers\n",
"- the best user experience, both in the Python SDKs and in the graphical interface\n",
"- the best integrations with other tools,\n",
"including\n",
"[Lightning](https://docs.wandb.ai/guides/integrations/lightning),\n",
"[Keras](https://docs.wandb.ai/guides/integrations/keras),\n",
"[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n",
"and even\n",
"[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n",
"and\n",
"- the best tools for collaboration.\n",
"\n",
"Below, we'll take care to point out which logging and management features\n",
"are available via generic interfaces in Lightning and which are W&B-specific."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"print(wandb.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding it to our experiment running code is extremely easy,\n",
"relative to the features we get, which is\n",
"one of the main selling points of W&B.\n",
"\n",
"We get most of our new experiment management features just by changing a single variable, `logger`, from\n",
"`TensorBoardLogger` to `WandbLogger`\n",
"and adding two lines of code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll see what each of these lines does for us below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this logger is built into and maintained by PyTorch Lightning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"\n",
"WandbLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to complete the rest of this notebook,\n",
"you'll need a Weights & Biases account.\n",
"\n",
"As with GitHub, the free tier for personal, academic, and open source work\n",
"is very generous.\n",
"\n",
"The Text Recognizer project will fit comfortably within the free tier.\n",
"\n",
"Run the cell below and follow the prompts to log in or create an account, or sign up\n",
"[here](https://wandb.ai/signup)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the cell below to launch an experiment tracked with Weights & Biases.\n",
"\n",
"The experiment can take between 3 and 10 minutes to run.\n",
"In that time, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n",
" --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1\n",
" \n",
"last_expt = wandb.run\n",
"\n",
"wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see some new things in our output.\n",
"\n",
"For example, there's a note from `wandb` that the data is saved locally\n",
"and also synced to their servers.\n",
"\n",
"There's a link to a webpage for viewing the logged data and a name for our experiment --\n",
"something like `dandy-sunset-1`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The local logging and cloud syncing happen with minimal impact on performance,\n",
"because `wandb` launches a separate process to listen for events and upload them.\n",
"\n",
"That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, head to the link in the notebook output\n",
"that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n",
"\n",
"There's no need to wait for training to finish.\n",
"\n",
"The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(last_expt.url)\n",
"IFrame(last_expt.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have landed on the run page\n",
"([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n",
"which collects up all of the information for a single experiment into a collection of tabs.\n",
"\n",
"We'll work through these tabs from top to bottom.\n",
"\n",
"Each header is also a link to the documentation for a tab."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n",
"This tab has an icon that looks like `(i)` or 🛈.\n",
"\n",
"The top section of this tab has high-level information about our run:\n",
"- Timing information, like start time and duration\n",
"- System hardware, hostname, and basic environment info\n",
"- Git repository link and state\n",
"\n",
"This information is collected and logged automatically.\n",
"\n",
"The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n",
"and summary metrics.\n",
"\n",
"Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n",
"\n",
"This tab has a line plot icon, something like 📈.\n",
"\n",
"It's also the default page you land on when looking at a W&B run.\n",
"\n",
"Charts are generated for everything we `.log` from PyTorch Lightning. The charts here are interactive and editable, and changes persist.\n",
"\n",
"Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n",
"\n",
"We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n",
"This tab has a computer chip icon.\n",
"\n",
"It contains\n",
"- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n",
"- CPU metrics: memory usage, utilization, thread counts\n",
"- Disk and network I/O levels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n",
"This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n",
"\n",
"The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n",
"This tab has an icon that looks like a stylized command prompt, `>_`.\n",
"\n",
"It contains information that was printed to the stdout.\n",
"\n",
"This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n",
"\n",
"Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelSummary\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n",
"\n",
"For more on Lightning `Callback`s, see\n",
"[Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n",
"This tab has a stylized document icon, something like 📄.\n",
"\n",
"You can use this tab to view any files saved with `wandb.save`.\n",
"\n",
"For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n",
"which we'll discuss shortly.\n",
"\n",
"But a few pieces of information automatically collected by W&B end up in this tab.\n",
"\n",
"Some highlights:\n",
" - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n",
" - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n",
"This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n",
"\n",
"This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n",
"\n",
"We store two kinds of binary files:\n",
" - `run_table`s of model inputs and outputs\n",
" - `model` checkpoints\n",
"\n",
"We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interactive Tables of Logged Media"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returning to the Charts tab,\n",
"notice that we have model inputs and outputs logged in structured tables\n",
"under the train, validation, and test sections.\n",
"\n",
"These tables are interactive as well\n",
"([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n",
"They support basic exploratory data analysis and are compatible with W&B's collaboration features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to charts in our run page, these tables also have their own pages inside the W&B web app."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n",
"table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n",
"\n",
"print(table_data_url)\n",
"IFrame(src=table_data_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting this to work requires more effort and more W&B-specific code\n",
"than the other features we've seen so far.\n",
"\n",
"We'll briefly explain the implementation here, for those who are interested.\n",
"\n",
"We use a custom Lightning `Callback`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n",
"\n",
"\n",
"ImageToTextTableLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n",
"\n",
"The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n",
"\n",
"This behavior is sensible for metrics, which are low overhead, but not so much for media,\n",
"where we'd rather subsample and avoid holding on to too much information.\n",
"\n",
"So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n",
"\n",
"The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.lit_models.base import BaseImageToTextLitModel\n",
"\n",
"BaseImageToTextLitModel.add_on_logged_batches??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Projects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Everything we've seen so far has been related to a single run or experiment.\n",
"\n",
"Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n",
"\n",
"We organize our runs into \"projects\" and view them on the W&B \"project page\" \n",
"([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n",
"\n",
"By default in the Lightning integration, the project name is determined based on directory information.\n",
"This default can be overridden in the code when creating a `WandbLogger`,\n",
"but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what the project page looks like for a longer-running project with lots of experiments.\n",
"\n",
"The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n",
"\n",
"print(project_url)\n",
"IFrame(src=project_url, width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n",
"\n",
"We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Artifacts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond logging metrics and metadata from runs,\n",
"we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below pulls up all of the artifacts associated with the experiment we just ran."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Click on one of the `model` checkpoints -- the specific version doesn't matter.\n",
"\n",
"There are a number of tabs here.\n",
"\n",
"The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n",
"\n",
"The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n",
"which are added by default by the `WandbLogger`.\n",
"\n",
"The \"Files\" tab contains the actual file contents of the artifact.\n",
"\n",
"On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n",
"including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n",
"\n",
"You can click on these to explore the different versions and even directly compare them.\n",
"\n",
"If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Artifact storage is part of the W&B free tier.\n",
"\n",
"The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n",
"\n",
"The former is sufficient to store ~700 model checkpoints for the Text Recognizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can track your data storage and compare it to your limits at this URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n",
"\n",
"print(storage_tracker_url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Programmatic Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also programmatically access our data and metadata via the `wandb` API\n",
"([docs](https://docs.wandb.ai/guides/track/public-api-guide)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wb_api = wandb.Api()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run = wb_api.run(\"/\".join( # fetch a run given\n",
" [last_expt.entity, # the user or org it was logged to\n",
" last_expt.project, # the \"project\", usually one of several per repo/application\n",
" last_expt.id] # and a unique ID\n",
"))\n",
"\n",
"hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n",
"\n",
"hist.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hist.groupby(\"epoch\")[\"train/loss\"].mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the run also gives us access to the artifacts it logged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# which artifacts were created and logged?\n",
"artifacts = run.logged_artifacts()\n",
"\n",
"for artifact in artifacts:\n",
" print(f\"artifact of type {artifact.type}: {artifact.name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thanks to our `ImageToTextTableLogger`,\n",
"we can easily recreate training or validation data that came out of our `DataLoader`s,\n",
"which is normally ephemeral:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n",
"artifact_dir = Path(artifact.download(root=\"training/logs\"))\n",
"image_dir = artifact_dir / \"media\" / \"images\"\n",
"\n",
"images = [path for path in image_dir.iterdir()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"from IPython.display import Image\n",
"\n",
"Image(str(random.choice(images)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Advanced W&B API Usage: MLOps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the strengths of a well-instrumented experiment tracking system is that it allows\n",
"automatic relation of information:\n",
"what were the inputs when this model's gradient spiked?\n",
"Which models have been trained on this dataset,\n",
"and what was their performance?\n",
"\n",
"Having access and automation around this information is necessary for \"MLOps\",\n",
"which applies contemporary DevOps principles to ML projects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below pull down the training data\n",
"for the model currently running the FSDL Text Recognizer app.\n",
"\n",
"This is just intended as a demonstration of what's possible,\n",
"so don't worry about understanding every piece of this,\n",
"and feel free to skip past it.\n",
"\n",
"MLOps is still a nascent field, and these tools and workflows are likely to change.\n",
"\n",
"For example, just before the course launched, W&B released a\n",
"[Model Registry layer](https://docs.wandb.ai/guides/models)\n",
"on top of artifact logging that aims to improve the developer experience for these workflows."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start from the same project we looked at in the project view:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n",
"\n",
"text_recognizer_project "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we search it for the text recognizer model currently being used in production:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# collect all versions of the text-recognizer ever put into production by...\n",
"\n",
"for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n",
" if art_type.name == \"prod-ready\": # for the prod-ready type\n",
" # and grabbing the text-recognizer\n",
" production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n",
"\n",
"# and then get the one that's currently being tested in CI by...\n",
"for text_recognizer in production_text_recognizers:\n",
" if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n",
" in_prod_text_recognizer = text_recognizer\n",
"\n",
"# view its metadata at the url or in the notebook\n",
"in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n",
"\n",
"print(in_prod_text_recognizer_url)\n",
"IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From its metadata, we can get information about how it was \"staged\" to be put into production,\n",
"and in particular which model checkpoint was used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staging_run = in_prod_text_recognizer.logged_by()\n",
"\n",
"training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n",
"training_ckpt.name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That checkpoint was logged by a training experiment, which is available as metadata.\n",
"\n",
"We can look at the training run for that model, either here in the notebook or at its URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"training_run = training_ckpt.logged_by()\n",
"print(training_run.url)\n",
"IFrame(src=training_run.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And from there, we can access logs and metadata about training,\n",
"confident that we are working with the model that is actually in production.\n",
"\n",
"For example, we can pull down the data we logged and analyze it locally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_results = training_run.history(samples=10000)\n",
"training_results.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n",
"training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 10\n",
"training_results[\"validation/loss\"].dropna().iloc[10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The charts and webpages in Weights & Biases\n",
"are substantially more useful than ephemeral stdouts or raw logs on disk.\n",
"\n",
"If you're spun up on the project,\n",
"they accelerate debugging, exploration, and discovery.\n",
"\n",
"If not, they're not so much useful as they are overwhelming.\n",
"\n",
"We need to synthesize the raw logged data into information.\n",
"This helps us communicate our work with other stakeholders,\n",
"preserve knowledge and prevent repetition of work,\n",
"and surface insights faster.\n",
"\n",
"These workflows are supported by the W&B Reports feature\n",
"([docs here](https://docs.wandb.ai/guides/reports)),\n",
"which mixes W&B charts and tables with explanatory markdown text and embeds.\n",
"\n",
"Below are some common report patterns,\n",
"their use cases, and an example of each."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of the examples are from the FSDL Text Recognizer project.\n",
"You can find more of them\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n",
"where we've organized them into a report!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dashboard Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dashboards are a structured subset of the output from one or more experiments,\n",
"designed for quickly surfacing issues or insights,\n",
"like an accuracy or performance regression\n",
"or a change in the data distribution.\n",
"\n",
"Use cases:\n",
"- show the basic state of an ongoing experiment\n",
"- compare one experiment to another\n",
"- select the most important charts so you can spin back up into context on a project more quickly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n",
"\n",
"IFrame(src=dashboard_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pull Request Documentation Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In most software codebases,\n",
"pull requests are a key focal point\n",
"for units of work that combine\n",
"short-term communication and long-term information tracking.\n",
"\n",
"In ML codebases, it's more difficult to bring\n",
"sufficient information together to make PRs as useful.\n",
"At FSDL, we like to add documentary\n",
"reports with one or a small number of charts\n",
"that connect logged information in the experiment management system\n",
"to state in the version control software.\n",
"\n",
"Use cases:\n",
"- communication of results within a team, e.g. code review\n",
"- record-keeping that links pull request pages to raw logged info and makes it discoverable\n",
"- improving confidence in PR correctness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n",
"\n",
"IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Blog Post Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With sufficient effort, the logged data in the experiment management system\n",
"can be made clear enough to be consumed,\n",
"sufficiently contextualized to be useful outside the team, and\n",
"even beautiful.\n",
"\n",
"The result is a report that's closer to a blog post than a dashboard or internal document.\n",
"\n",
"Use cases:\n",
"- communication between teams or vertically in large organizations\n",
"- external technical communication for branding and recruiting\n",
"- attracting users or contributors\n",
"\n",
"Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n",
"[Boris Dayma](https://twitter.com/borisdayma)\n",
"and others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n",
"\n",
"IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameter Optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many of our choices, like the depth of our network, the nonlinearities of our layers,\n",
"and the learning rate and other parameters of our optimizer, cannot be\n",
"([easily](https://arxiv.org/abs/1606.04474))\n",
"chosen by descent of the gradient of a loss function.\n",
"\n",
"But these parameters that impact the values of the parameters\n",
"we directly optimize with gradients, or _hyperparameters_,\n",
"can still be optimized,\n",
"essentially by trying options and selecting the values that worked best.\n",
"\n",
"In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n",
"\n",
"Expending more compute can squeeze out small amounts of additional validation or test performance\n",
"that makes for impressive results on leaderboards but typically doesn't translate\n",
"into better user experience.\n",
"\n",
"In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n",
"built into your other tooling.\n",
"\n",
"Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n",
"([docs](https://docs.wandb.ai/guides/sweeps)).\n",
"\n",
"It also supports a number of more advanced tools, like\n",
"[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n",
"for early termination of poorly-performing runs.\n",
"\n",
"We can use the same training script and we don't need to run an optimization server.\n",
"\n",
"We just need to write a configuration yaml file\n",
"([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n",
"like the one below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile training/simple-overfit-sweep.yaml\n",
"# first we specify what we're sweeping\n",
"# we specify a program to run\n",
"program: training/run_experiment.py\n",
"# we optionally specify how to run it, including setting default arguments\n",
"command: \n",
" - ${env}\n",
" - ${interpreter}\n",
" - ${program}\n",
" - \"--wandb\"\n",
" - \"--overfit_batches\"\n",
" - \"1\"\n",
" - \"--log_every_n_steps\"\n",
" - \"25\"\n",
" - \"--max_epochs\"\n",
" - \"100\"\n",
" - \"--limit_test_batches\"\n",
" - \"0\"\n",
" - ${args} # these arguments come from the sweep parameters below\n",
"\n",
"# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n",
"method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n",
"metric:\n",
" name: train/loss\n",
" goal: minimize\n",
"parameters: \n",
" # LineCNN hyperparameters\n",
" window_width:\n",
" values: [8, 16, 32, 64]\n",
" window_stride:\n",
" values: [4, 8, 16, 32]\n",
" # Transformer hyperparameters\n",
" tf_layers:\n",
" values: [1, 2, 4, 8]\n",
" # we can also fix some values, just like we set default arguments\n",
" gpus:\n",
" value: 1\n",
" model_class:\n",
" value: LineCNNTransformer\n",
" data_class:\n",
" value: IAMLines\n",
" loss:\n",
" value: transformer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the config we launch a \"controller\":\n",
"a lightweight process that just decides what hyperparameters to try next\n",
"and coordinates the heavier-weight training.\n",
"\n",
"This lives on the W&B servers, so there are no headaches about opening ports for communication,\n",
"cleaning up when it's done, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n",
"simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we can launch an \"agent\" to follow the orders of the controller:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# interrupt twice to terminate this cell if it's running too long,\n",
"# it can be over 15 minutes with some hyperparameters\n",
"\n",
"!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n",
"\n",
"If not provided, the agent will keep running for random or Bayesian sweeps\n",
"until the sweep is terminated, which can be done from the W&B interface."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agents make for a slick workflow for distributing sweeps across GPUs.\n",
"\n",
"We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n",
"which controls which GPUs are accessible by a process, to launch\n",
"parallel agents on separate GPUs on the same machine."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n",
"# open another terminal\n",
"CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n",
"# and so on\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We include optional exercises with the labs for learners who want to dive deeper on specific topics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟 Contribute to a hyperparameter search."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n",
"\n",
"There are ~10,000,000 potential hyperparameter combinations,\n",
"and each takes 30 minutes to test,\n",
"so checking each possibility will take over 500 years of compute time.\n",
"Best get cracking then!\n",
"\n",
"Run the cell below to pull up a dashboard and print the URL where you can check on the current status."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_entity = \"fullstackdeeplearning\"\n",
"sweep_project = \"fsdl-line-recognizer-2022\"\n",
"sweep_id = \"e0eo43eu\"\n",
"sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n",
"\n",
"print(sweep_url)\n",
"IFrame(src=sweep_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also retrieve information about the sweep from the API,\n",
"including the hyperparameters being swept over."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hyperparams = sweep_info.config[\"parameters\"]\n",
"hyperparams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to contribute to this sweep,\n",
"run the cell below after changing the count to a number greater than 0.\n",
"\n",
"Each iteration runs for 30 minutes if it does not crash,\n",
"e.g. due to out-of-memory errors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"count = 0 # off by default, increase it to join in!\n",
"\n",
"if count:\n",
" !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Write some manual logging in `wandb`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the FSDL Text Recognizer codebase,\n",
"we almost exclusively log to W&B through Lightning,\n",
"rather than through the `wandb` Python SDK.\n",
"\n",
"If you're interested in learning how to use W&B directly, e.g. with another training framework,\n",
"try out this quick exercise that introduces the key players in the SDK."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n",
"\n",
"It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n",
"\n",
"Add W&B metric and artifact logging to this cell:\n",
"- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n",
"- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"import random\n",
"\n",
"import wandb\n",
"\n",
"\n",
"os.makedirs(\"logs\", exist_ok=True)\n",
"\n",
"project = \"trying-wandb\"\n",
"config = {\"steps\": 50}\n",
"\n",
"\n",
"with wandb.init(project=project, config=config) as run:\n",
" steps = wandb.config[\"steps\"]\n",
" \n",
" for ii in range(steps):\n",
" loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n",
" \n",
" with open(\"logs/hello.txt\", \"w\") as f:\n",
" f.write(\"hello from wandb, my dudes!\")\n",
" \n",
" run_id = run.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hello_run = wb_api.run(f\"{project}/{run_id}\")\n",
"\n",
"# check for logged loss data\n",
"if \"loss\" not in hello_run.history().keys():\n",
" print(\"loss not logged 🥲\")\n",
"else:\n",
" print(\"loss logged successfully 🥞\")\n",
" if len(hello_run.history()[\"loss\"]) != steps:\n",
" print(\"loss not logged on all steps 🥲\")\n",
" else:\n",
" print(\"loss logged on all steps 🥞\")\n",
"\n",
"artifacts = hello_run.logged_artifacts()\n",
"\n",
"# check for artifact with the right name\n",
"if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n",
" print(\"hello artifact not logged 🥲\")\n",
"else:\n",
" print(\"hello artifact logged successfully 🥞\")\n",
" # check for the file inside the artifacts\n",
" if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n",
" print(\"could not find hello.txt 🥲\")\n",
" else:\n",
" print(\"hello.txt logged successfully 🥞\")\n",
" \n",
" \n",
"hello_run"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you observe interesting phenomena during training,\n",
"from promising hyperparameter combos to software bugs to strange model behavior,\n",
"turn the charts into a W&B report and share it with the FSDL community or\n",
"[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n",
"with a link to them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# check the sweep_info.config above to see the model and data hyperparameters\n",
"# read through the --help output for all potential arguments\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n",
" --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n",
" --help # remove this line to run an experiment instead of printing help\n",
" \n",
"last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟🌟🌟 Add logging of tensor statistics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to logging model inputs and outputs as human-interpretable media,\n",
"it's also frequently useful to see information about their numerical values."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're interested in learning more about metric calculation and logging with Lightning,\n",
"use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n",
"to add tensor statistic logging to the `LineCNNTransformer`.\n",
"\n",
"`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n",
"\n",
"All three are useful, but start by adding just one."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_model/base.py` and `text_recognizer/lit_model/transformer.py` files\n",
"- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n",
" - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n",
"- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n",
" - Base your code on the calculation and logging of the `val_cer` metric.\n",
" - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ",
"collapsed_sections": [],
"name": "lab04_experiments.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab04/text_recognizer/__init__.py
================================================
"""Modules for creating and running a text recognizer."""
================================================
FILE: lab04/text_recognizer/callbacks/__init__.py
================================================
from .model import ModelSizeLogger
from .optim import LearningRateMonitor
from . import imtotext
from .imtotext import ImageToTextTableLogger as ImageToTextLogger
================================================
FILE: lab04/text_recognizer/callbacks/imtotext.py
================================================
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
from .util import check_and_warn
class ImageToTextTableLogger(pl.Callback):
    """Logs the inputs and outputs of an image-to-text model to Weights & Biases.

    Each logged batch becomes a table with columns input_image / ground_truth_string /
    predicted_string, written through the logger's `log_table` method.
    """

    def __init__(self, max_images_to_log=32, on_train=True):
        super().__init__()
        # clamp to the range [1, 32]
        self.max_images_to_log = min(max(max_images_to_log, 1), 32)
        self.on_train = on_train
        self._required_keys = ["gt_strs", "pred_strs"]

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, output, batch, batch_idx):
        """Log a predictions table for a training batch when `on_train` is set."""
        if not self.on_train or not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_table", "image-to-text table"):
            return
        self._log_image_text_table(trainer, output, batch, "train/predictions")

    @rank_zero_only
    def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        """Log a predictions table for a validation batch.

        `dataloader_idx` defaults to 0 so the hook also works when Lightning omits it
        (single validation dataloader).
        """
        if not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_table", "image-to-text table"):
            return
        self._log_image_text_table(trainer, output, batch, "validation/predictions")

    def _log_image_text_table(self, trainer, output, batch, key):
        """Write up to `max_images_to_log` (image, ground truth, prediction) rows under `key`."""
        xs, _ = batch
        gt_strs = output["gt_strs"]
        pred_strs = output["pred_strs"]

        mx = self.max_images_to_log
        xs, gt_strs, pred_strs = xs[:mx], gt_strs[:mx], pred_strs[:mx]

        xs = [wandb.Image(x) for x in xs]

        rows = zip(*[xs, gt_strs, pred_strs])
        columns = ["input_image", "ground_truth_string", "predicted_string"]
        trainer.logger.log_table(key=key, columns=columns, data=list(rows))

    def has_metrics(self, output):
        """Return True if `output` is a mapping that holds every key in `_required_keys`.

        Non-mapping outputs (e.g. None, or a bare loss tensor) return False instead of
        raising AttributeError.
        """
        if not hasattr(output, "keys"):
            return False
        return all(key in output.keys() for key in self._required_keys)
class ImageToTextCaptionLogger(pl.Callback):
    """Logs the inputs and outputs of an image-to-text model to Weights & Biases.

    Input images are logged through the logger's `log_image` method, with the predicted
    strings as captions.
    """

    def __init__(self, max_images_to_log=32, on_train=True):
        super().__init__()
        # clamp to the range [1, 32]
        self.max_images_to_log = min(max(max_images_to_log, 1), 32)
        self.on_train = on_train
        self._required_keys = ["gt_strs", "pred_strs"]

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, output, batch, batch_idx):
        """Log captioned images for a training batch, but only when `on_train` is set."""
        if not self.on_train:
            return
        self._maybe_log(trainer, output, batch, "train/predictions")

    @rank_zero_only
    def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        """Log captioned images for a validation batch (`dataloader_idx` may be omitted by Lightning)."""
        self._maybe_log(trainer, output, batch, "validation/predictions")

    @rank_zero_only
    def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        """Log captioned images for a test batch (`dataloader_idx` may be omitted by Lightning)."""
        self._maybe_log(trainer, output, batch, "test/predictions")

    def _maybe_log(self, trainer, output, batch, key):
        """Log under `key` if the outputs carry strings and the logger supports `log_image`."""
        if not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_image", "image-to-text"):
            return
        self._log_image_text_caption(trainer, output, batch, key)

    def _log_image_text_caption(self, trainer, output, batch, key):
        """Log up to `max_images_to_log` input images captioned with their predicted strings."""
        xs, _ = batch
        gt_strs = output["gt_strs"]
        pred_strs = output["pred_strs"]

        mx = self.max_images_to_log
        xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], pred_strs[:mx]
        trainer.logger.log_image(key, xs, caption=pred_strs)

    def has_metrics(self, output):
        """Return True if `output` is a mapping that holds every key in `_required_keys`.

        Non-mapping outputs (e.g. None, or a bare loss tensor) return False instead of
        raising AttributeError.
        """
        if not hasattr(output, "keys"):
            return False
        return all(key in output.keys() for key in self._required_keys)
================================================
FILE: lab04/text_recognizer/callbacks/model.py
================================================
import os
from pathlib import Path
import tempfile
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import torch
from .util import check_and_warn, logging
try:
import torchviz
has_torchviz = True
except ImportError:
has_torchviz = False
class ModelSizeLogger(pl.Callback):
    """Logs information about model size (in parameters and on disk)."""

    def __init__(self, print_size=True):
        super().__init__()
        self.print_size = print_size

    @rank_zero_only
    def on_fit_start(self, trainer, module):
        self._run(trainer, module)

    def _run(self, trainer, module):
        """Measure the model, optionally print its disk size, and log both metrics under `size/`."""
        metrics = {}
        metrics["mb_disk"] = self.get_model_disksize(module)
        metrics["nparams"] = count_params(module)

        if self.print_size:
            print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB")

        metrics = {f"size/{key}": value for key, value in metrics.items()}

        # Trainer(logger=False) leaves trainer.logger as None; warn and skip rather than crash.
        if check_and_warn(trainer.logger, "log_metrics", "model size"):
            return
        trainer.logger.log_metrics(metrics, step=-1)

    @staticmethod
    def get_model_disksize(module):
        """Determine the model's size on disk, in MB (1e6 bytes), by saving its state dict to a temp file."""
        with tempfile.NamedTemporaryFile() as f:
            torch.save(module.state_dict(), f)
            f.flush()  # ensure no buffered bytes are missing when the file size is read
            size_mb = os.path.getsize(f.name) / 1e6
        return size_mb
class GraphLogger(pl.Callback):
    """Logs a compute graph as an image.

    The graph is rendered with torchviz from the tensor found at `output_key` in the
    training-step outputs. Only the first training batch is attempted.
    """

    def __init__(self, output_key="logits"):
        super().__init__()
        self.graph_logged = False
        self.output_key = output_key
        if not has_torchviz:
            raise ImportError("GraphLogger requires torchviz.")

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx=None):
        """Try once to log the graph.

        `dataloader_idx` is optional because newer Lightning versions no longer pass it to this hook.
        """
        if self.graph_logged:
            return
        try:
            # NOTE(review): this nested layout matches an older Lightning output format — confirm for your version.
            graph_outputs = outputs[0][0]["extra"][self.output_key]
        except (KeyError, IndexError, TypeError):
            logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}")
        else:
            self.log_graph(trainer, module, graph_outputs)
        self.graph_logged = True

    @staticmethod
    def log_graph(trainer, module, outputs):
        """Render the autograd graph of `outputs` to PNG in the logger's run dir and log it as an image."""
        if check_and_warn(trainer.logger, "log_image", "graph"):
            return
        params_dict = dict(list(module.named_parameters()))
        graph = torchviz.make_dot(outputs, params=params_dict)
        graph.format = "png"
        fname = Path(trainer.logger.experiment.dir) / "graph"
        graph.render(fname)
        fname = str(fname.with_suffix("." + graph.format))
        trainer.logger.log_image(key="graph", images=[fname])
def count_params(module):
    """Return the total number of scalar entries across all parameters of a Torch Module."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
================================================
FILE: lab04/text_recognizer/callbacks/optim.py
================================================
import pytorch_lightning as pl
KEY = "optimizer"
class LearningRateMonitor(pl.callbacks.LearningRateMonitor):
"""Extends Lightning's LearningRateMonitor with a prefix.
Logs the learning rate during training. See the docs for
pl.callbacks.LearningRateMonitor for details.
"""
def _add_prefix(self, *args, **kwargs) -> str:
return f"{KEY}/" + super()._add_prefix(*args, **kwargs)
================================================
FILE: lab04/text_recognizer/callbacks/util.py
================================================
import logging
logging.basicConfig(level=logging.WARNING)
def check_and_warn(logger, attribute, feature):
    """Warn when `logger` lacks `attribute`, which means `feature` cannot be logged.

    Intended as an early-exit guard: `if check_and_warn(...): return`.

    Returns:
        True if the attribute is missing (a warning was emitted), otherwise False.
    """
    if not hasattr(logger, attribute):
        warn_no_attribute(feature, attribute)
        return True
    return False
def warn_no_attribute(blocked_feature, missing_attribute):
    """Emit a root-logger warning that `blocked_feature` is skipped for lack of `missing_attribute`."""
    logging.warning(
        "Unable to log %s: logger does not have attribute %s.",
        blocked_feature,
        missing_attribute,
    )
================================================
FILE: lab04/text_recognizer/data/__init__.py
================================================
"""Module containing submodules for each dataset.
Each dataset is defined as a class in that submodule.
The datasets should have a .config method that returns
any configuration information needed by the model.
Most datasets define their constants in a submodule
of the metadata module that is parallel to this one in the
hierarchy.
"""
from .util import BaseDataset
from .base_data_module import BaseDataModule
from .mnist import MNIST
from .emnist import EMNIST
from .emnist_lines import EMNISTLines
from .iam_paragraphs import IAMParagraphs
from .iam_lines import IAMLines
================================================
FILE: lab04/text_recognizer/data/base_data_module.py
================================================
"""Base DataModule class."""
import argparse
import os
from pathlib import Path
from typing import Collection, Dict, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
from text_recognizer import util
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.shared as metadata
def load_and_print_info(data_module_class, argv=None) -> None:
    """Build a data module from command-line arguments, prepare and set it up, then print it.

    Args:
        data_module_class: class providing `add_to_argparse(parser)`, a constructor taking the
            parsed Namespace, and `prepare_data()` / `setup()` methods.
        argv: optional list of argument strings. None (the default) makes argparse read
            `sys.argv[1:]`, as before.
    """
    parser = argparse.ArgumentParser()
    data_module_class.add_to_argparse(parser)
    args = parser.parse_args(argv)
    dataset = data_module_class(args)
    dataset.prepare_data()
    dataset.setup()
    print(dataset)
def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
dl_dirname.mkdir(parents=True, exist_ok=True)
filename = dl_dirname / metadata["filename"]
if filename.exists():
return filename
print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
util.download_url(metadata["url"], filename)
print("Computing SHA-256...")
sha256 = util.compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
return filename
BATCH_SIZE = 128


def _count_available_cpus() -> int:
    """Return how many CPUs this process may use.

    `os.sched_getaffinity` honours CPU pinning but exists only on some Unix platforms
    (not macOS or Windows). Elsewhere, fall back to `os.cpu_count()`, which can return None.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


NUM_AVAIL_CPUS = _count_available_cpus()
NUM_AVAIL_GPUS = torch.cuda.device_count()

# sensible multiprocessing defaults: at most one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS
# but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS
class BaseDataModule(pl.LightningDataModule):
    """Base for all of our LightningDataModules.

    Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html
    """

    def __init__(self, args: Optional[argparse.Namespace] = None) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS)

        gpus = self.args.get("gpus", None)
        # None means GPUs were not requested. A count of 0 (int, or the string "0") means CPU-only.
        # Any other int (e.g. -1 for "all") or non-empty string (e.g. a device list "0,1") requests GPUs.
        if isinstance(gpus, str):
            self.on_gpu = gpus.strip() not in ("", "0")
        else:
            self.on_gpu = isinstance(gpus, int) and gpus != 0

        # Make sure to set the variables below in subclasses
        self.input_dims: Tuple[int, ...]
        self.output_dims: Tuple[int, ...]
        self.mapping: Collection
        self.data_train: Union[BaseDataset, ConcatDataset]
        self.data_val: Union[BaseDataset, ConcatDataset]
        self.data_test: Union[BaseDataset, ConcatDataset]

    @classmethod
    def data_dirname(cls):
        return metadata.DATA_DIRNAME

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--batch_size",
            type=int,
            default=BATCH_SIZE,
            help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=DEFAULT_NUM_WORKERS,
            help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.",
        )
        return parser

    def config(self):
        """Return important settings of the dataset, which will be passed to instantiate models."""
        return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping}

    def prepare_data(self, *args, **kwargs) -> None:
        """Take the first steps to prepare data for use.

        Use this method to do things that might write to disk or that need to be done only from a single GPU
        in distributed settings (so don't set state `self.x = y`).
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """Perform final setup to prepare data for consumption by DataLoader.

        Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """

    def train_dataloader(self):
        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def val_dataloader(self):
        return DataLoader(
            self.data_val,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def test_dataloader(self):
        return DataLoader(
            self.data_test,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )
================================================
FILE: lab04/text_recognizer/data/emnist.py
================================================
"""EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present."""
import json
import os
from pathlib import Path
import shutil
from typing import Sequence
import zipfile
import h5py
import numpy as np
import toml
from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset, split_dataset
import text_recognizer.metadata.emnist as metadata
from text_recognizer.stems.image import ImageStem
from text_recognizer.util import temporary_working_directory
# Offset added to raw EMNIST labels: the first class indices are reserved for special tokens.
NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS

# Filesystem locations, re-exported from text_recognizer.metadata.emnist.
RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME
METADATA_FILENAME = metadata.METADATA_FILENAME
DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME

# Processing settings.
SAMPLE_TO_BALANCE = True  # If true, take at most the mean number of instances per class.
TRAIN_FRAC = 0.8  # fraction of the train+val pool assigned to training by split_dataset
class EMNIST(BaseDataModule):
    """EMNIST dataset of handwritten characters and digits.

    "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
    and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
    From https://www.nist.gov/itl/iad/image-group/emnist-dataset

    The data split we will use is
    EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
    """

    def __init__(self, args=None):
        super().__init__(args)

        self.mapping = metadata.MAPPING
        # character -> class index (the inverse of `mapping`)
        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
        self.transform = ImageStem()
        self.input_dims = metadata.DIMS
        self.output_dims = metadata.OUTPUT_DIMS

    def prepare_data(self, *args, **kwargs) -> None:
        """Download and process the raw data unless the processed HDF5 file already exists."""
        if not os.path.exists(PROCESSED_DATA_FILENAME):
            _download_and_process_emnist()

    def setup(self, stage: str = None) -> None:
        """Load the processed arrays: train/val split for "fit", test set for "test", both for None."""
        if stage == "fit" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_trainval = f["x_train"][:]
                self.y_trainval = f["y_train"][:].squeeze().astype(int)

            data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
            self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)

        if stage == "test" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_test = f["x_test"][:]
                self.y_test = f["y_test"][:].squeeze().astype(int)
            self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)

    def __repr__(self):
        """Describe the dataset.

        Works before `setup()` and after a partial setup: the data_* attributes are only
        annotated in the base class, so they may not exist yet.
        """
        basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n"
        data_train = getattr(self, "data_train", None)
        data_val = getattr(self, "data_val", None)
        data_test = getattr(self, "data_test", None)
        if data_train is None and data_val is None and data_test is None:
            return basic

        train_size, val_size, test_size = (
            len(ds) if ds is not None else None for ds in (data_train, data_val, data_test)
        )
        data = f"Train/val/test sizes: {train_size}, {val_size}, {test_size}\n"
        if data_train is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
            )
        return basic + data
def _download_and_process_emnist():
    """Fetch the raw EMNIST archive described by the metadata TOML file, then process it into HDF5."""
    dataset_info = toml.load(METADATA_FILENAME)
    _download_raw_dataset(dataset_info, DL_DATA_DIRNAME)
    _process_raw_dataset(dataset_info["filename"], DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path):
    """Convert the raw EMNIST zip archive into the processed HDF5 file plus an essentials JSON.

    Args:
        filename: name of the downloaded zip archive. It is opened by this relative name, so it
            presumably resolves inside `dirname` once the working directory is switched — TODO confirm.
        dirname: directory that holds the downloaded archive.

    Side effects:
        Writes PROCESSED_DATA_FILENAME (HDF5 datasets x_train, y_train, x_test, y_test, stored as
        "u1" with lzf compression) and ESSENTIALS_FILENAME (JSON with "characters" and
        "input_shape"). Deletes the extracted "matlab" directory at the end.
    """
    print("Unzipping EMNIST...")
    # NOTE(review): assumes temporary_working_directory chdir's into `dirname` for the block,
    # which the relative "matlab/..." paths below rely on — confirm in text_recognizer.util.
    with temporary_working_directory(dirname):
        with zipfile.ZipFile(filename, "r") as zf:
            # only the ByClass split is extracted
            zf.extract("matlab/emnist-byclass.mat")

        from scipy.io import loadmat

        # NOTE: If importing at the top of module, would need to list scipy as prod dependency.

        print("Loading training data from .mat file")
        data = loadmat("matlab/emnist-byclass.mat")
        # Rows are reshaped to 28x28 and the two spatial axes swapped (a per-image transpose);
        # presumably because the .mat stores pixels column-major — TODO confirm.
        x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
        y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
        x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
        y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
        # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices

        if SAMPLE_TO_BALANCE:
            print("Balancing classes to reduce amount of data")
            x_train, y_train = _sample_to_balance(x_train, y_train)
            x_test, y_test = _sample_to_balance(x_test, y_test)

        print("Saving to HDF5 in a compressed format...")
        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
            # "u1" = unsigned 8-bit ints; labels must therefore stay below 256
            f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
            f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
            f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
            f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")

        print("Saving essential dataset parameters to text_recognizer/data...")
        # raw label -> character; each mapping row is presumably (label, character code point)
        mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
        characters = _augment_emnist_characters(list(mapping.values()))
        essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
        with open(ESSENTIALS_FILENAME, "w") as f:
            json.dump(essentials, f)

        print("Cleaning up...")
        shutil.rmtree("matlab")
def _sample_to_balance(x, y):
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(42)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_inds = []
for label in np.unique(y.flatten()):
inds = np.where(y == label)[0]
sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
all_sampled_inds.append(sampled_inds)
ind = np.concatenate(all_sampled_inds)
x_sampled = x[ind]
y_sampled = y[ind]
return x_sampled, y_sampled
def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset
iam_characters = [
" ",
"!",
'"',
"#",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
":",
";",
"?",
]
# Also add special tokens:
# - CTC blank token at index 0
# - Start token at index 1
# - End token at index 2
# - Padding token at index 3
# NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this!
return ["", "
", "", " and ", *tokens, " and ", *tokens, ""]
self.end_index = self.inverse_mapping["",
""]
self.end_token = inverse_mapping[""]
self.end_token = inverse_mapping[""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing,\n",
"matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanging of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randrange(len(x_train)) # randrange excludes len(x_train), so idx is always a valid index\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect to see that something around one in ten predictions are correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how good our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run our `backward`s pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randrange(len(x_train)) # always a valid index, unlike randint(0, len(x_train))\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset was, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need's not out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
" self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
" self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameters`\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built in to `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameters`,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch assigns the `nn.Parameters`\n",
"of any child `nn.Module`s to the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten-line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (`fc` and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine to move electrons around."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, the problem of CPUs,\n",
"which are general-purpose computing devices,\n",
"being too slow is solved by using more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"the redheaded stepchild of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
" url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
" \n",
" def __init__(self, dir, bs=32):\n",
" self.dir = dir\n",
" self.bs = bs\n",
" self.path = self.dir / self.filename\n",
"\n",
" def prepare_data(self):\n",
" if not (self.path).exists():\n",
" content = requests.get(self.url + self.filename).content\n",
" self.path.open(\"wb\").write(content)\n",
"\n",
" def setup(self):\n",
" with gzip.open(self.path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
" x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
" )\n",
" \n",
" self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
" self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
" \n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've defined the `MNISTDataModule` class\n",
"and attached our `fit` method to a model class like `MLP` or `MNISTLogistic`,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab05/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch for iteratively generating those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"That's in contrast to the more common way frameworks are designed,\n",
"which is to provide abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of over-riding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging # import some stdlib components to control what's displayed\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"try: # try using the LinearRegression LightningModule defined above\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n",
"\n",
" model = LinearRegression()\n",
"\n",
" # we'll explain how the Trainer works in a bit\n",
" trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
" trainer.fit(model=model) \n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
" print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n",
"\n",
"finally: # bring back info-level logging\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be over-ridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"\n",
"    def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
"        super().__init__()  # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
"        # set some constants, like the train/val split\n",
"        self.size = size\n",
"        self.batch_size = batch_size\n",
"        self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
"        n_train = math.floor(self.size * train_frac)\n",
"        self.train_indices = list(range(n_train))\n",
"        # validation starts right after the last training index, so the two splits never overlap\n",
"        self.val_indices = list(range(n_train, self.size))\n",
"\n",
"        # under the hood, we've still got a torch Dataset\n",
"        self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once on each machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n",
" if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n",
" self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
" self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self): # prepares state that needs to be set once per node\n",
" pass # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and over-riding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"outweighing the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Also, implementing metric calculation during training adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab05/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is contained in a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other positions.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation invariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution can be described compactly by just those few weights,\n",
"which are known as its _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we repeatedly apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper on convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (to the right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represents (layer IDs at the bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint, by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that custom here and used `__repr__`. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with interspersed \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work on,\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting the wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like dropout\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory-constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first thing you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"This helps prevent training runs from crashing\n",
"at the end of the first epoch,\n",
"which is otherwise the first time validation loops are triggered\n",
"and is sometimes hours into training,\n",
"by crashing them quickly at the start.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `<P>`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 03: Transformers and Paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The fundamental reasons why the Transformer is such\n",
"a powerful and popular architecture\n",
"- Core intuitions for the behavior of Transformer architectures\n",
"- How to use a convolutional encoder and a Transformer decoder to recognize\n",
"entire paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 3\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Transformers?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal in building a text recognizer is to take a two-dimensional image\n",
"and convert it into a one-dimensional sequence of characters\n",
"from some alphabet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional neural networks,\n",
"discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n",
"are great at encoding images,\n",
"taking them from their raw pixel values\n",
"to a more semantically meaningful numerical representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But how do we go from that to a sequence of letters?\n",
"And what's especially tricky:\n",
"the number of letters in an image is independent of its size.\n",
"A screenshot of this document has a much higher density of letters\n",
"than a close-up photograph of a piece of paper.\n",
"How do we get a _variable-length_ sequence of letters,\n",
"where the length need have nothing to do with the size of the input tensor?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n",
"they were\n",
"[originally introduced](https://arxiv.org/abs/1706.03762)\n",
"for transforming one sequence into another,\n",
"as in machine translation.\n",
"This makes them a natural fit for processing language.\n",
"\n",
"But they have also found success in other domains --\n",
"at the time of this writing, large transformers\n",
"dominate the\n",
"[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n",
"that has become a de facto standard for comparing models\n",
"and are finding\n",
"[application in reinforcement learning](https://arxiv.org/abs/2106.01345)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we will use a Transformer as a key component of our final architecture:\n",
"we will encode our input images with a CNN\n",
"and then read them out into a text sequence with a Transformer.\n",
"\n",
"Before trying out this new model,\n",
"let's first get an understanding of why the Transformer architecture\n",
"has become so popular by walking through its history\n",
"and then get some intuition for how it works\n",
"by looking at some\n",
"[recent work](https://transformer-circuits.pub/)\n",
"on explaining the behavior of both toy models and state-of-the-art language models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kmKqjbvd-Mj3"
},
"source": [
"## Why not convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SRqkUMdM-OxU"
},
"source": [
"In the ancient beforetimes (i.e. 2016),\n",
"the best models for natural language processing were all\n",
"_recurrent_ neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional networks were also occasionally used,\n",
"but they suffered from a serious issue:\n",
"their architectural biases don't fit text.\n",
"\n",
"First, _translation equivariance_ no longer holds.\n",
"The beginning of a piece of text is often quite different from the middle,\n",
"so the absolute position matters.\n",
"\n",
"Second, _locality_ is not as important in language.\n",
"The name of a character that hasn't appeared in thousands of pages\n",
"can become salient when someone asks, \"Whatever happened to\n",
"[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n",
"\n",
"Consider interpreting a piece of text like the Python code below:\n",
"```python\n",
"def do(arg1, arg2, arg3):\n",
" a = arg1 + arg2\n",
" b = arg3[:3]\n",
" c = a * b\n",
" return c\n",
"\n",
"print(do(1, 1, \"ayy lmao\"))\n",
"```\n",
"\n",
"After a `(` we expect a `)`,\n",
"but possibly very long afterwards,\n",
"[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n",
"and similarly we expect a `]` at some point after a `[`.\n",
"\n",
"For translation variance, consider\n",
"that we interpret `*` not by\n",
"comparing it to its neighbors\n",
"but by looking at `a` and `b`.\n",
"We mix knowledge learned through experience\n",
"with new facts learned while reading --\n",
"also known as _in-context learning_.\n",
"\n",
"In a longer text,\n",
"[e.g. the one you are reading now](./lab03_transformers.ipynb),\n",
"the translation variance of text is clearer.\n",
"Every lab notebook begins with the same header,\n",
"setting up the environment,\n",
"but that header never appears elsewhere in the notebook.\n",
"Later positions need to be processed in terms of the previous entries.\n",
"\n",
"Unlike an image, we cannot simply rotate or translate our \"camera\"\n",
"and get a new valid text.\n",
"[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n",
"that can be read without regard to position."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The field of formal language theory,\n",
"which has deep mutual influence with computer science,\n",
"gives one way of explaining the issues with convolutional networks:\n",
"they can only understand languages with _finite contexts_,\n",
"where all the information can be found within a finite window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The immediate solution, drawing from the connections to computer science, is\n",
"[recursion](https://www.google.com/search?q=recursion).\n",
"A network whose output on the final entry of the sequence is a recursive function\n",
"of all the previous entries can build up knowledge\n",
"as it reads the sequence and treat early entries quite differently than it does late ones."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa6cbTlImkEh"
},
"source": [
"In pseudo-code, such a _recurrent neural network_ module might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lKtBoPnglPrW"
},
"source": [
"```python\n",
"def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n",
" next_inputs = input_module(xs[-1])\n",
" next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n",
" return output_module(next_inputs, next_hiddens)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbJPSMnEm516"
},
"source": [
"If you've had formal computer science training,\n",
"then you may be familiar with the power of recursion,\n",
"e.g. the\n",
"[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n",
"that gave its name to the now much better-known\n",
"[startup incubator](https://www.ycombinator.com/).\n",
"\n",
"The particular form of recursion used by\n",
"recurrent neural networks implements a\n",
"[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n",
"\n",
"> If you know a lot of computer science,\n",
"you might be concerned by this connection.\n",
"What about other\n",
"[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n",
"Where are the neural network architectures for differentiable\n",
"[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n",
"Check out Graph Neural Networks,\n",
"[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "63mMTbEBpVuE"
},
"source": [
"Recurrent networks are able to achieve\n",
"[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n",
"\n",
"There are many popular recurrent architectures,\n",
"from the beefy and classic\n",
"[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n",
"to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n",
"([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n",
"all of which have roughly similar capabilities but\n",
"[some of which are easier to train](https://arxiv.org/abs/1611.09913)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwQHVTIslOku"
},
"source": [
"In the same sense that MLPs can model \"any\" feedforward function,\n",
"in principle even basic RNNs\n",
"[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n",
"\n",
"In particular they can model any\n",
"[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n",
"which is a formal way of saying that they can in principle\n",
"do anything a computer is capable of doing.\n",
"\n",
"The question is then..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3J8EoGN3pu7P"
},
"source": [
"## Why aren't we all using RNNs?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDwNWaevpt_3"
},
"source": [
"The guarantees that MLPs can model any function\n",
"or that RNNs can model Turing machines\n",
"provide decent intuition but are not directly practically useful.\n",
"Among other reasons, they don't guarantee learnability --\n",
"that starting from random parameters we can find the parameters\n",
"that implement a given function.\n",
"The\n",
"[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n",
"than would seem from basic theoretical and empirical analysis.\n",
"\n",
"One way of understanding capacity to model language is\n",
"[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n",
"In this model of formal languages,\n",
"Turing machines sit at the top\n",
"([practically speaking](https://arxiv.org/abs/math/0209332)).\n",
"\n",
"With better mathematical models,\n",
"RNNs and LSTMs can be shown to be\n",
"[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n",
"with RNNs looking more like\n",
"[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n",
"and LSTMs coming in\n",
"[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n",
"\n",
"More controversially:\n",
"the Chomsky hierarchy is great for understanding syntax and grammar,\n",
"which makes it great for building parsers\n",
"and working with formal languages,\n",
"but the goal in _natural_ language processing is to understand _natural_ language.\n",
"Most humans' natural language is far from strictly grammatical,\n",
"but that doesn't mean it is nonsense.\n",
"\n",
"And to really \"understand\" language means\n",
"to understand its semantic content, which is fuzzy.\n",
"The most important thing for handling the fuzzy semantic content\n",
"of language is not whether you can recall\n",
"[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n",
"but whether you can model probabilistic relationships between concepts\n",
"in addition to grammar and syntax."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These both leave theoretical room for improvement over current recurrent\n",
"language and sequence models.\n",
"\n",
"But the real cause of the rise of Transformers is that..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dsu1ebvAp-3Z"
},
"source": [
"## Transformers are designed to train fast at scale on contemporary hardware."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4abU5adsPGs"
},
"source": [
"The Transformer architecture has several important features,\n",
"discussed below,\n",
"but one of the most important reasons why it is successful\n",
"is that it can be more easily trained at scale.\n",
"\n",
"This scalability is the focus of the discussion in the paper\n",
"that introduced the architecture,\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"and\n",
"[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n",
"\n",
"The recursion in RNNs is inherently sequential:\n",
"the dependence on the outputs from earlier in the sequence\n",
"means computations within an example cannot be parallelized.\n",
"\n",
"So RNNs must batch across examples to scale,\n",
"but as sequence length grows this hits memory bandwidth limits.\n",
"Serving up large batches quickly with good randomness guarantees\n",
"is also hard to optimize,\n",
"especially in distributed settings.\n",
"\n",
"The Transformer architecture,\n",
"on the other hand,\n",
"can be readily parallelized within a single example sequence,\n",
"in addition to parallelization across batches.\n",
"This can lead to massive performance gains for a fixed scale,\n",
"which means larger, higher capacity models\n",
"can be trained on larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Mzk2haFC_G1"
},
"source": [
"How does the architecture achieve this parallelizability?\n",
"\n",
"Let's start with the architecture diagram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u59eu4snLQfp"
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n",
"\n",
"display.Image(url=base_url + \"/aiayn-figure-1.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ez-XEQ7M0UlR"
},
"source": [
"> To head off a bit of confusion\n",
" in case you've worked with Transformer architectures before:\n",
" the original \"Transformer\" is an encoder/decoder architecture.\n",
" Many LLMs, like GPT models, are decoder only,\n",
" because this has turned out to scale well,\n",
" and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n",
" it's all text anyways.\n",
" We, however, will be using them across modalities,\n",
" so we need an explicit encoder,\n",
" as above. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ok4ksBi4vp89"
},
"source": [
"First focusing on the encoder (left):\n",
"the encoding at a given position is a function of all of the inputs.\n",
"But it is not a function of the previous _encodings_:\n",
"we produce the encodings \"all at once\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN7C-_OqzHP"
},
"source": [
"The decoder (right) does use previous \"outputs\" as its inputs,\n",
"but those outputs are not the vectors of layer activations\n",
"(aka embeddings)\n",
"that are produced by the network.\n",
"They are instead the processed outputs,\n",
"after a `softmax` and an `argmax`.\n",
"\n",
"We could obtain these outputs by processing the embeddings,\n",
"much like in a recurrent architecture.\n",
"In fact, that is one way that Transformers are run.\n",
"It's what happens in the `.forward` method\n",
"of the model we'll be training for character recognition:\n",
"`ResnetTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L5_2WMmtDnJn"
},
"source": [
"Let's look at that forward method\n",
"and connect it to the diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FR5pk4kEyCGg"
},
"outputs": [],
"source": [
"from text_recognizer.models import ResnetTransformer\n",
"\n",
"\n",
"ResnetTransformer.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-J5UFDoPzPbq"
},
"source": [
"`.encode` happens first -- that's the left side of the diagram.\n",
"\n",
"The encoder can in principle be anything\n",
"that produces a sequence of fixed-length vectors,\n",
"but here it's\n",
"[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n",
"\n",
"Then we start iterating over the sequence\n",
"in the `for` loop.\n",
"\n",
"Focus on the first few lines of code.\n",
"We apply `.decode` (right side of the diagram)\n",
"to the outputs so far.\n",
"\n",
"Once we have a new `output`, we apply `.argmax`\n",
"to turn the logits into a concrete prediction of\n",
"a particular token.\n",
"\n",
"This is added as the last output token\n",
"and then the loop happens again."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTcy8-rV1dHr"
},
"source": [
"Run this way, our model looks very much like a recurrent architecture:\n",
"we call the model on its own outputs\n",
"to generate the next value.\n",
"These types of models are also referred to as\n",
"[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n",
"because we predict (as we do in _regression_)\n",
"the next value based on our own (_auto_) output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But Transformers are designed to be _trained_ more scalably than RNNs,\n",
"not necessarily to _run inference_ more scalably,\n",
"and it's actually not the case that our model's `.forward` is called during training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCxMSAWmEKBt"
},
"source": [
"Let's look at what happens during training\n",
"by checking the `training_step`\n",
"of the `LightningModule`\n",
"we use to train our Transformer models,\n",
"the `TransformerLitModel`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o7q8N7P2w4H"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import TransformerLitModel\n",
"\n",
"TransformerLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VgNNOjvzC4y"
},
"source": [
"Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tz-6NGPR4dUr"
},
"source": [
"Let's look at `.teacher_forward`,\n",
"and in particular its type signature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ILc2oWET4i2Z"
},
"outputs": [],
"source": [
"TransformerLitModel.teacher_forward??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lf32lpgrDb__"
},
"source": [
"This is known as \"teacher forcing\".\n",
"The \"teacher\" signal is \"forcing\"\n",
"the model to behave as though\n",
"it got the answer right.\n",
"\n",
"[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n",
"It's more effective here\n",
"because the right teaching signal\n",
"for our network is the target data,\n",
"which we have access to during training,\n",
"whereas in an RNN the best teaching signal\n",
"would be the target embedding vector,\n",
"which we do not know.\n",
"\n",
"During inference, when we don't have access to the ground truth,\n",
"we revert to the autoregressive `.forward` method."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This \"trick\" allows Transformer architectures to readily scale\n",
"up models to the parameter counts\n",
"[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAjqpJm9uUuU"
},
"source": [
"## Is there more to Transformers than just a training trick?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kWCYXeHv7Qc9"
},
"source": [
"[Very](https://arxiv.org/abs/2005.14165),\n",
"[very](https://arxiv.org/abs/1909.08053),\n",
"[very](https://arxiv.org/abs/2205.01068)\n",
"large Transformer models have powered the most recent wave of exciting results in ML, like\n",
"[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n",
"\n",
"They are also the first machine learning models to have come anywhere close to\n",
"deserving the term _artificial intelligence_ --\n",
"a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is surprising because the models and their training procedure are\n",
"(relatively speaking)\n",
"pretty _simple_,\n",
"even if it doesn't feel that way on first pass."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The basic Transformer architecture is just a bunch of\n",
"dense matrix multiplications and non-linearities --\n",
"it's perhaps simpler than a convolutional architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And advances since the introduction of Transformers in 2017\n",
"have not in the main been made by\n",
"creating more sophisticated model architectures\n",
"but by increasing the scale of the base architecture,\n",
"or if anything making it simpler, as in\n",
"[GPT-type models](https://arxiv.org/abs/2005.14165),\n",
"which drop the encoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V1HQS9ey8GMc"
},
"source": [
"These models are also trained on very simple tasks:\n",
"most LLMs are just trying to predict the next element in the sequence,\n",
"given the previous elements --\n",
"a task simple enough that Claude Shannon,\n",
"father of information theory, was\n",
"[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n",
"\n",
"These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n",
"e.g. by scraping the web."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They are also trained in a simple fashion:\n",
"first-order stochastic optimizers, like SGD or an\n",
"[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n",
"intended for the most basic of optimization problems,\n",
"that scale more readily than the second-order optimizers\n",
"that dominate other areas of optimization."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kz9HPDoy7OAl"
},
"source": [
"This is\n",
"[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n",
"of work in ML:\n",
"simple, even seemingly wasteful,\n",
"architectures that scale well and are robust\n",
"to implementation details\n",
"eventually outstrip more clever but\n",
"also more finicky approaches that are harder to scale.\n",
"This lesson has led some to declare that\n",
"[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n",
"in machine learning, and perhaps even in artificial intelligence."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdN9o2Y771YZ"
},
"source": [
"> That is not to say that because the algorithms are relatively simple,\n",
" training a model at this scale is _easy_ --\n",
" [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n",
" [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n",
" [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n",
" But choosing the simplest algorithm at every step makes solving the scaling problem feasible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baVGf6gKFOvs"
},
"source": [
"The importance of scale is the key lesson from the Transformer architecture,\n",
"far more than any theoretical considerations\n",
"or any of the implementation details.\n",
"\n",
"That said, these large Transformer models are capable of\n",
"impressive behaviors and understanding how they achieve them\n",
"is of intellectual interest.\n",
"Furthermore, like any architecture,\n",
"there are common failure modes,\n",
"of the model and of the modelers who use them,\n",
"that need to be taken into account."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"Below, we'll cover two key intuitions about Transformers:\n",
"Transformers are _residual_, like ResNets,\n",
"and they compose _low rank_ sequence transformations.\n",
"Together, this means they act somewhat like a computer,\n",
"reading from and writing to a \"tape\" or memory\n",
"with a sequence of simple instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"We'll also cover a surprising implementation detail:\n",
"despite being commonly used for sequence modeling,\n",
"by default the architecture is _position insensitive_."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uni0VTCr9lev"
},
"source": [
"### Intuition #1: Transformers are highly residual."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MoBt-JLJz-d"
},
"source": [
"> The discussion of these intuitions summarizes the discussion in\n",
"[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n",
"from\n",
"[Anthropic](https://www.anthropic.com/),\n",
"an AI safety and research company.\n",
"The figures below are from that blog post.\n",
"It is the spiritual successor to the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"covered in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab).\n",
"If you want to truly understand Transformers,\n",
"we highly recommend you check it out,\n",
"including the\n",
"[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UUbNVvM5Ferm"
},
"source": [
"It's easy to see that ResNets are residual --\n",
"it's in the name, after all.\n",
"\n",
"But Transformers are,\n",
"in some sense,\n",
"even more closely tied to residual computation\n",
"than are ResNets:\n",
"ResNets and related architectures include downsampling,\n",
"so there is not a direct path from inputs to outputs.\n",
"\n",
"In Transformers, the exact same shape is maintained\n",
"from the moment tokens are embedded,\n",
"through dozens or hundreds of intermediate layers,\n",
"and until they are \"unembedded\" into class logits.\n",
"The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n",
"\n",
"The residual stream is easy to see with a change of perspective.\n",
"Instead of the usual architecture diagram above,\n",
"which emphasizes the layers acting on the tensors,\n",
"consider this alternative view,\n",
"which emphasizes the tensors as they pass through the layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRMlVguKKW6y"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-residual-view.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9K3N7ilVkB3"
},
"source": [
"For definitions of variables and terms, see the\n",
"[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arvciE-kKd_L"
},
"source": [
"Note that this is a _decoder-only_ Transformer architecture --\n",
"so it should be compared with the right-hand side of the original architecture diagram above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvrRMd_RKp_G"
},
"source": [
"Notice that outputs of the attention blocks \n",
"and of the MLP layers are\n",
"added to their inputs, as in a ResNet.\n",
"These operations are represented as \"Add & Norm\" layers in the classical diagram;\n",
"normalization is ignored here for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8n_iT-FFAbK"
},
"source": [
"This total commitment to residual operations\n",
"means the size of the embeddings\n",
"(referred to as the \"model dimension\" or the \"embedding dimension\",\n",
"here and below `d_model`)\n",
"stays the same throughout the entire network.\n",
"\n",
"That means, for example,\n",
"that the output of each layer can be used as input to the \"unembedding\" layer\n",
"that produces logits.\n",
"We can read out the computations of intermediate layers\n",
"just by passing them through the unembedding layer\n",
"and examining the logit tensor.\n",
"See\n",
"[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n",
"for detailed experiments and interactive notebooks.\n",
"\n",
"In short, we observe a sort of \"progressive refinement\"\n",
"of the next-token prediction\n",
"as the embeddings proceed, depthwise, through the network."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ovh_3YgY9z2h"
},
"source": [
"### Intuition #2: Transformer heads learn low rank transformations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpNmozlnOdPC"
},
"source": [
"In the original paper and in\n",
"most presentations of Transformers,\n",
"the attention layer is written like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PA7me8gNP5LE"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pseudo-typed PyTorch (based loosely on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n",
"that looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oeict_6wGJgD"
},
"source": [
"```python\n",
"def classic_attention(\n",
" Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" K: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" return torch.softmax(Q @ K.T) @ V\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8pewU90DSuOR"
},
"source": [
"This is effectively exactly\n",
"how it is written\n",
"in PyTorch,\n",
"apart from implementation details\n",
"(look for `bmm` for the matrix multiplications and a `softmax` call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrgTpKFvOhwc"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"F._scaled_dot_product_attention??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebDXZ0tlSe7g"
},
"source": [
"But the best way to write an operation so that a computer can execute it quickly\n",
"is not necessarily the best way to write it so that a human can understand it --\n",
"otherwise we'd all be coding in assembly.\n",
"\n",
"And this is a strange way to write it --\n",
"you'll notice that what we normally think of\n",
"as the \"inputs\" to the layer are not shown.\n",
"\n",
"We can instead write out the attention layer\n",
"as a function of the inputs $x$.\n",
"We write it for a single \"attention head\".\n",
"Each attention layer includes a number of heads\n",
"that read and write from the residual stream\n",
"simultaneously and independently.\n",
"We also add the output layer weights $W_O$\n",
"and we get:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LuFNR67tQpsf"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(\\underbrace{x^TW_Q^T}_Q \\underbrace{W_Kx}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVnBjjfOLwxP"
},
"source": [
"or, in pseudo-typed PyTorch:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmpOm-HfGaNz"
},
"source": [
"```python\n",
"def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n",
" key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n",
" key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n",
" # maps queries of residual stream to keys from residual stream, independent of position\n",
"\n",
" value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n",
" output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n",
" value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n",
" # transformation applied to each token, regardless of position\n",
"\n",
" attention_logits = x @ key_query_circuit @ x.T\n",
" attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits)\n",
" # maps positions to positions, often very sparse\n",
"\n",
" value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n",
"\n",
" return attention_map @ value_output # transformed tokens filtered by attention map\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dC0eqxZ6UAGT"
},
"source": [
"Consider the `key_query_circuit`\n",
"and `value_output_circuit`\n",
"matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n",
"\n",
"The key/query dimension, `d_head`,\n",
"is small relative to the model's dimension, `d_model`,\n",
"so $W_{QK}$ and $W_{OV}$ are very low rank,\n",
"[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n",
"that they factorize into two matrices,\n",
"one with a smaller number of rows\n",
"and another with a smaller number of columns.\n",
"That number is called the _rank_.\n",
"\n",
"When computing, these matrices are better represented via their components,\n",
"rather than computed directly,\n",
"which leads to the normal implementation of attention.\n",
"\n",
"In a large language model,\n",
"the ratio of residual stream dimension, `d_model`, to\n",
"the dimension of a single head, `d_head`, is huge, often 100:1.\n",
"That means each query, key, and value computed at a position\n",
"is a fairly simple, low-dimensional feature of the residual stream at that position.\n",
"\n",
"For visual intuition,\n",
"we compare what a matrix whose rank is 1/100th of full rank looks like,\n",
"relative to a full rank matrix of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LUbojJMiW2C"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"\n",
"low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n",
"full_rank = torch.randn(100, 100)\n",
"plt.figure(); plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n",
"plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqBst92-OVka"
},
"source": [
"The pattern in the first matrix is very simple,\n",
"relative to the pattern in the second matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkCGrs9EiVh4"
},
"source": [
"Another feature of low rank transformations is\n",
"that they have a large nullspace or kernel --\n",
"these are directions we can move the input without changing the output.\n",
"\n",
"That means that many changes to the residual stream won't affect the behavior of this head at all."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVz2dQgzhD4p"
},
"source": [
"### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVlzwR03m8mC"
},
"source": [
"The combination of residuality\n",
"(changes are added to the current input)\n",
"and low rank\n",
"(only a small subspace is changed by each head)\n",
"drastically changes the intuition about Transformers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qqjZI2jKe6HH"
},
"source": [
"Rather than being an \"embedding of a token in its context\",\n",
"the residual stream becomes something more like a memory or a scratchpad:\n",
"one layer reads a small bit of information from the stream\n",
"and writes a small bit of information back to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YIBkxlqepjc"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-layer-residual.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RtsKhkLfk00l"
},
"source": [
"The residual stream works like a memory because it is roomy enough\n",
"that these actions need not interfere:\n",
"the subspaces targeted by reads and writes are small relative to the ambient space,\n",
"so they can mostly coexist without overwriting one another.\n",
"\n",
"Additionally, the dimension of each head is still in the 100s in large models,\n",
"and\n",
"[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n",
"in them, so the effective number of degrees of freedom is\n",
"actually larger than the dimension.\n",
"This phenomenon allows high-dimensional tensors to serve as\n",
"[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n",
"There are\n",
"[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n",
"\n",
"Together, this means an early layer can write information to the stream\n",
"that can be used by later layers -- by many of them at once, possibly much later.\n",
"Later layers can learn to edit this information,\n",
"e.g. deleting it,\n",
"if doing so reduces the loss,\n",
"but by default the information is preserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EragIygzJg86"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-stream-read-write.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKIaUZjwkpW7"
},
"source": [
"Lastly, the softmax in the attention has a sparsifying effect,\n",
"and so many attention heads read from\n",
"just one token and write to just one other token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dN6VcJqIMKnB"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-token-to-token.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Repeatedly reading information from an external memory\n",
"and using it to decide which operation to perform\n",
"and where to write the results\n",
"is at the core of the\n",
"[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n",
"For a concrete example, the\n",
"[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n",
"includes a dissection of a form of \"pointer arithmetic\"\n",
"that appears in some models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kLFh7Mvnolr"
},
"source": [
"This point of view seems\n",
"very promising for explaining numerous\n",
"otherwise perhaps counterintuitive features of Transformer models.\n",
"\n",
"- This framework predicts that Transformers will readily copy and paste information,\n",
"which might explain phenomena like\n",
"[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n",
"\n",
"- It also readily explains\n",
"[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n",
"an important component of why Transformers perform well on medium-length texts\n",
"and in few-shot learning.\n",
"\n",
"- Transformers also perform better on reasoning tasks when the text\n",
"[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n",
"is added to their input prompt.\n",
"This is partly due to the fact that that prompt is associated,\n",
"in the dataset, with clearer reasoning,\n",
"and since the models are trained to predict which tokens tend to appear\n",
"after an input, they tend to produce better reasoning with that prompt --\n",
"an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n",
"that act to store intermediate information,\n",
"making for a richer residual stream\n",
"for reading and writing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RyLRzgG-93yB"
},
"source": [
"### Implementation detail: Transformers are position-insensitive by default."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oR6PnrlA_hJ2"
},
"source": [
"In the attention calculation\n",
"each token can query each other token,\n",
"with no regard for order.\n",
"Furthermore, the construction of queries, keys, and values\n",
"is based on the content of the embedding vector,\n",
"which does not automatically include its position.\n",
"\"dog bites man\" and \"man bites dog\" are identical, as in\n",
"[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n",
"\n",
"For most sequences,\n",
"this is unacceptable:\n",
"absolute and relative position matter\n",
"and we cannot use the future to predict the past.\n",
"\n",
"We need to add two pieces to get a Transformer architecture that's usable for next-token prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWHxGJz2-6ZK"
},
"source": [
"First, the simpler piece:\n",
"\"causal\" attention,\n",
"so-named because it ensures that values earlier in the sequence\n",
"are not influenced by later values, which would\n",
"[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0c42xi6URYB4"
},
"source": [
"The most common solution is straightforward:\n",
"we calculate attention between all tokens,\n",
"then throw out non-causal values by \"masking\" them\n",
"(this is before applying the softmax,\n",
"so masking means adding $-\\infty$).\n",
"\n",
"This feels wasteful --\n",
"why are we calculating values we don't need?\n",
"Trying to be smarter would be harder,\n",
"and might rely on operations that aren't as optimized as\n",
"matrix multiplication and addition.\n",
"Furthermore, it's \"only\" twice as many operations,\n",
"so it doesn't even show up in $O$-notation.\n",
"\n",
"A sample attention mask generated by our code base is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NXaWe6pT-9jV"
},
"outputs": [],
"source": [
"from text_recognizer.models import transformer_util\n",
"\n",
"\n",
"attention_mask = transformer_util.generate_square_subsequent_mask(100)\n",
"\n",
"ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n",
"plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n",
"print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This solves our causality problem,\n",
"but we still don't have positional information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZamUE4WIoGS2"
},
"source": [
"The standard technique\n",
"is to add alternating sines and cosines\n",
"of increasing frequency to the embeddings\n",
"(there are\n",
"[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n",
"most notably\n",
"[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n",
"Each position in the sequence is then uniquely identifiable\n",
"from the pattern of these values.\n",
"\n",
"> Furthermore, for the same reason that\n",
" [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n",
" translations, e.g. relative positions, are fairly easy to express as linear transformations\n",
" of sines and cosines."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IDG2uOsaELU0"
},
"source": [
"We superimpose this positional information on our embeddings.\n",
"Note that because the model is residual,\n",
"this position information will be by default preserved\n",
"as it passes through the network,\n",
"so it doesn't need to be repeatedly added."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what this positional encoding looks like in our codebase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Zk62Q-a-1Ax"
},
"outputs": [],
"source": [
"PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n",
"\n",
"pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n",
"\n",
"ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n",
"print(pe[:4, :8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep2ClIWvqDms"
},
"source": [
"When we add the positional information to our embeddings,\n",
"both the embedding information and the positional information\n",
"are approximately preserved,\n",
"as can be visually assessed below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJuFjoCzC0Y4"
},
"outputs": [],
"source": [
"fake_embeddings = torch.randn_like(pe) * 0.5\n",
"\n",
"ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n",
"\n",
"fake_embeddings_with_pe = fake_embeddings + pe\n",
"\n",
"plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHIzBxDkEmH8"
},
"source": [
"A [similar technique](https://arxiv.org/abs/2103.06450)\n",
"is used to also incorporate positional information into the image embeddings,\n",
"which are flattened before being fed to the decoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HC1N85wl8dvn"
},
"source": [
"### Learn more about Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lJwYxkjTk15t"
},
"source": [
"We're only able to give a flavor and an intuition for Transformers here.\n",
"\n",
"To improve your grasp on the nuts and bolts, check out the\n",
"[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n",
"which is surprisingly approachable,\n",
"as far as ML research papers go.\n",
"The\n",
"[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n",
"adds code and commentary to the original paper,\n",
"which makes it even more digestible.\n",
"For something even friendlier, check out the\n",
"[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n",
"by Jay Alammar, which has an accompanying\n",
"[video](https://youtu.be/-QH8fRhqFHM).\n",
"\n",
"Anthropic's work on\n",
"[Transformer Circuits](https://transformer-circuits.pub/),\n",
"summarized above, has some of the best material\n",
"for building theoretical understanding\n",
"and is still being updated with extensions and applications of the framework.\n",
"The\n",
"[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n",
"are a great aid for checking and building your understanding.\n",
"\n",
"But they are fairly math-heavy.\n",
"If you have more of a software engineering background, see\n",
"Transformer Circuits co-author Nelson Elhage's blog post\n",
"[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n",
"\n",
"For a gentler introduction to the intuition for Transformers,\n",
"check out Brandon Rohrer's\n",
"[Transformers From Scratch](https://e2eml.school/transformers.html)\n",
"tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qg7zntJES-aT"
},
"source": [
"An aside:\n",
"the matrix multiplications inside attention dominate\n",
"the big-$O$ runtime of Transformers.\n",
"So trying to make the attention mechanism more efficient, e.g. linear time,\n",
"has generated a lot of research\n",
"(review paper\n",
"[here](https://arxiv.org/abs/2009.06732)).\n",
"Despite drawing a lot of attention, so to speak,\n",
"at the time of writing in mid-2022, these methods\n",
"[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n",
"so it isn't likely to be worth the effort to spend time learning about them\n",
"unless you are a Transformer specialist."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCjXysEJ8g9_"
},
"source": [
"# Using Transformers to read paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsfKWnOvqjva"
},
"source": [
"Our simple convolutional model for text recognition from\n",
"[Lab 02b](https://fsdl.me/lab02b-colab)\n",
"could only handle cleanly-separated characters.\n",
"\n",
"It worked by sliding a LeNet-style CNN\n",
"over the image,\n",
"predicting a character for each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "njLdzBqy-I90"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist_lines = text_recognizer.data.EMNISTLines()\n",
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"\n",
"# for sliding, see the for loop over range(S)\n",
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0N6yDBQq8ns"
},
"source": [
"But unfortunately for us, handwritten text\n",
"doesn't come in neatly-separated characters\n",
"of equal size, so we trained our model on synthetic data\n",
"designed to work with that model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiqUVbj0sxLr"
},
"source": [
"Now that we have a better model,\n",
"we can work with better data:\n",
"paragraphs from the\n",
"[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oizsOAcKs-dD"
},
"source": [
"The cell below uses our `LightningDataModule`\n",
"to download and preprocess this data,\n",
"writing results to disk.\n",
"We can then spin up `DataLoader`s to give us batches.\n",
"\n",
"It can take several minutes to run the first time\n",
"on commodity machines,\n",
"with most time spent extracting the data.\n",
"On subsequent runs,\n",
"the time-consuming operations will not be repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uL9LHbjdsUbm"
},
"outputs": [],
"source": [
"iam_paragraphs = text_recognizer.data.IAMParagraphs()\n",
"\n",
"iam_paragraphs.prepare_data()\n",
"iam_paragraphs.setup()\n",
"xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n",
"\n",
"iam_paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBkFN9bbTm_S"
},
"source": [
"Now that we've got a batch,\n",
"let's take a look at some samples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hqaps8yxtBhU"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import wandb\n",
"\n",
"\n",
"def show(y):\n",
" y = y.detach().cpu() # bring back from accelerator if it's being used\n",
" return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 04: Experiment Management"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How experiment management brings observability to ML model development\n",
"- Which features of experiment management we use in developing the Text Recognizer\n",
"- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 4\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lab contains a large number of embedded iframes\n",
"that benefit from having a wide window.\n",
"The cell below makes the notebook as wide as your browser window\n",
"if `full_width` is set to `True`.\n",
"Full width is the default behavior in Colab,\n",
"so this cell is intended to improve the viewing experience in other Jupyter environments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Why experiment management?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To understand why we need experiment management for ML development,\n",
"let's start by running an experiment.\n",
"\n",
"We'll train a new model on a new dataset,\n",
"using the training script `training/run_experiment.py`\n",
"introduced in [Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use a CNN encoder and Transformer decoder, as in\n",
"[Lab 03](https://fsdl.me/lab03-colab),\n",
"but with some changes so we can iterate faster.\n",
"We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab),\n",
"and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n",
"from text_recognizer.data import IAMLines # processed version split into individual lines\n",
"from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n",
"\n",
"\n",
"print(IAM.__doc__)\n",
"\n",
"# uncomment a line below for details on either class\n",
"# IAMLines?? \n",
"# LineCNNTransformer??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below will train a model on 10% of the data for two epochs.\n",
"\n",
"It takes up to a few minutes to run on commodity hardware,\n",
"including data download and preprocessing.\n",
"As it's running, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"import torch\n",
"\n",
"\n",
"gpus = int(torch.cuda.is_available()) \n",
"\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the model trains, we're calculating lots of metrics --\n",
"loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n",
"and reporting them to the terminal.\n",
"\n",
"This is achieved by the built-in `.log` method\n",
"([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n",
"of the `LightningModule`,\n",
"and it is a very straightforward way to get basic information about your experiment as it's running\n",
"without leaving the context where you're running it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Learning to read\n",
"[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n",
"is something of a rite of passage for MLEs, but\n",
"let's consider what we can't see here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We're missing all metric values except the most recent --\n",
"we can see them as they stream in, but they're constantly overwritten.\n",
"We also can't associate them with timestamps, steps, or epochs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We also don't see any system metrics.\n",
"We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n",
"without launching a separate process.\n",
"And even if we do, those values will also not be saved and timestamped,\n",
"so we can't correlate them with other things during training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As we continue to run experiments, changing code and opening new terminals,\n",
"even the information we have or could figure out now will disappear.\n",
"Say you spot a weird error message during training,\n",
"but your session ends and the stdout is gone,\n",
"so you don't know exactly what it was.\n",
"Can you recreate the error?\n",
"Which git branch and commit were you on?\n",
"Did you have any uncommitted changes? Which arguments did you pass?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Also, model checkpoints containing the parameter values have been saved to disk.\n",
"Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n",
"As we run more and more experiments,\n",
"we'll want to slice and dice them to see if,\n",
"say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to save and log all of this information, and more, in order to make our model training\n",
"[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n",
"in short, so that we can understand, make decisions about, and debug our model training\n",
"by looking at logs and source code, without having to recreate it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we had to write the code to log all of this information ourselves, we'd be in for a world of hurt:\n",
"1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which means less time spent mastering the important stuff -- your data, your models, and your problem domain.\n",
"2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n",
"3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n",
"4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local Experiment Tracking with TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n",
"and it makes logging very easy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n",
"\n",
"It can also use a logger to push information elsewhere.\n",
"\n",
"By default, we use\n",
"[TensorBoard](https://www.tensorflow.org/tensorboard)\n",
"via the Lightning `TensorBoardLogger`,\n",
"which has been saving results to the local disk.\n",
"\n",
"Let's find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest experiment's directory\n",
"# by hand, you can just copy and paste it from the terminal\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n",
"filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n",
"latest_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!ls -lh {latest_log}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, we need to launch a TensorBoard server --\n",
"much like we need to launch a Jupyter server to use Jupyter notebooks.\n",
"\n",
"The cells below load an extension that lets you use TensorBoard inside of a notebook\n",
"the same way you'd use it from the command line, and then launch it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n",
"\n",
"port = 11717 # pick an open port on your machine\n",
"host = \"0.0.0.0\" # allow connections from the internet\n",
" # watch out! make sure you turn TensorBoard off\n",
"\n",
"%tensorboard --logdir {latest_log} --port {port} --host {host}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see some charts of metrics over time along with some charting controls.\n",
"\n",
"You can click around in this interface and explore it if you'd like,\n",
"but in the next section, we'll see that there are better tools for experiment management."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've run many experiments on this machine,\n",
"you can see all of their results by pointing TensorBoard\n",
"at the whole `lightning_logs` directory,\n",
"rather than just one experiment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For large numbers of experiments, the management experience is not great --\n",
"for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n",
"\n",
"It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n",
"which are important workflows as applications mature and teams grow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tensorboard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n",
"\n",
"If you run into any issues with the above cells failing to launch,\n",
"especially across iterations of this lab, run this cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorboard.manager\n",
"\n",
"# get the process IDs for all tensorboard instances\n",
"pids = [tb.pid for tb in tensorboard.manager.get_all()]\n",
"\n",
"done_with_tensorboard = False\n",
"\n",
"if done_with_tensorboard:\n",
" # kill processes\n",
" for pid in pids:\n",
" !kill {pid} 2> /dev/null\n",
" \n",
" # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n",
" !rm -rf {tensorboard.manager._get_info_dir()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiment Management with Weights & Biases"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How do we manage experiments when we hit the limits of local TensorBoard?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is powerful and flexible and very scalable,\n",
"but running it requires engineering effort and babysitting --\n",
"you're running a database, writing data to it,\n",
"and layering a web application over it.\n",
"\n",
"This is a fairly common workflow for web developers,\n",
"but not so much for ML engineers.\n",
"\n",
"You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n",
"and it's as simple as running the command `tensorboard dev upload`\n",
"pointed at your logging directory.\n",
"\n",
"But there are strict limits to this free service:\n",
"1GB of tensor data and 1GB of binary data.\n",
"A single Text Recognizer model checkpoint is ~100MB,\n",
"and that's not particularly large for a useful model.\n",
"\n",
"Furthermore, all data is public,\n",
"so if you upload the inputs and outputs of your model,\n",
"anyone who finds the link can see them.\n",
"\n",
"Overall, tensorboard.dev works very well for certain academic and open projects\n",
"but not for industrial ML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To avoid that narrow permissions and limits issue,\n",
"you could use [git LFS](https://git-lfs.github.com/)\n",
"to track the binary data and tensor data,\n",
"which is more likely to be sensitive than metrics.\n",
"\n",
"The Hugging Face ecosystem uses TensorBoard and git LFS.\n",
"\n",
"It includes the Hugging Face Hub, a git server much like GitHub,\n",
"but designed first and foremost for collaboration on models and datasets,\n",
"rather than collaboration on code.\n",
"For example, the Hugging Face Hub\n",
"[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n",
"and officially has\n",
"[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n",
"avoiding the\n",
"[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n",
"that make using git LFS with GitHub expensive.\n",
"\n",
"However, we prefer to avoid mixing software version control and experiment management.\n",
"\n",
"First, using the Hub requires maintaining an additional git remote,\n",
"which is a hard ask for many engineering teams.\n",
"\n",
"Secondly, git-style versioning is an awkward fit for logging --\n",
"is it really sensible to create a new commit for each logging event while you're watching live?\n",
"\n",
"Instead, we prefer to use systems that solve experiment management with _databases_."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n",
"The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n",
"tool is [MLflow](https://github.com/mlflow/mlflow/)\n",
"and there are a number of\n",
"[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n",
"\n",
"These tools generally avoid any need to worry about hosting\n",
"(unless data governance rules require a self-hosted version).\n",
"\n",
"For a sampling of publicly-posted opinions on experiment management tools,\n",
"see these discussions from Reddit:\n",
"\n",
"- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n",
"- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n",
"\n",
"Among these tools, the FSDL recommendation is\n",
"[Weights & Biases](https://wandb.ai),\n",
"which we believe offers\n",
"- the best user experience, both in the Python SDKs and in the graphical interface\n",
"- the best integrations with other tools,\n",
"including\n",
"[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n",
"[Keras](https://docs.wandb.ai/guides/integrations/keras),\n",
"[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n",
"and even\n",
"[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n",
"and\n",
"- the best tools for collaboration.\n",
"\n",
"Below, we'll take care to point out which logging and management features\n",
"are available via generic interfaces in Lightning and which are W&B-specific."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"print(wandb.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding it to our experiment running code is extremely easy,\n",
"relative to the features we get, which is\n",
"one of the main selling points of W&B.\n",
"\n",
"We get most of our new experiment management features just by changing a single variable, `logger`, from\n",
"`TensorboardLogger` to `WandbLogger`\n",
"and adding two lines of code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll see what each of these lines does for us below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this logger is built into and maintained by PyTorch Lightning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"\n",
"WandbLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to complete the rest of this notebook,\n",
"you'll need a Weights & Biases account.\n",
"\n",
"As with GitHub the free tier, for personal, academic, and open source work,\n",
"is very generous.\n",
"\n",
"The Text Recognizer project will fit comfortably within the free tier.\n",
"\n",
"Run the cell below and follow the prompts to log in or create an account or go\n",
"[here](https://wandb.ai/signup)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the cell below to launch an experiment tracked with Weights & Biases.\n",
"\n",
"The experiment can take between 3 and 10 minutes to run.\n",
"In that time, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n",
" --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1\n",
" \n",
"last_expt = wandb.run\n",
"\n",
"wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see some new things in our output.\n",
"\n",
"For example, there's a note from `wandb` that the data is saved locally\n",
"and also synced to their servers.\n",
"\n",
"There's a link to a webpage for viewing the logged data and a name for our experiment --\n",
"something like `dandy-sunset-1`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The local logging and cloud syncing happens with minimal impact on performance,\n",
"because `wandb` launches a separate process to listen for events and upload them.\n",
"\n",
"That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, head to the link in the notebook output\n",
"that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n",
"\n",
"There's no need to wait for training to finish.\n",
"\n",
"The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(last_expt.url)\n",
"IFrame(last_expt.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have landed on the run page\n",
"([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n",
"which collects up all of the information for a single experiment into a collection of tabs.\n",
"\n",
"We'll work through these tabs from top to bottom.\n",
"\n",
"Each header is also a link to the documentation for a tab."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n",
"This tab has an icon that looks like `(i)` or 🛈.\n",
"\n",
"The top section of this tab has high-level information about our run:\n",
"- Timing information, like start time and duration\n",
"- System hardware, hostname, and basic environment info\n",
"- Git repository link and state\n",
"\n",
"This information is collected and logged automatically.\n",
"\n",
"The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n",
"and summary metrics.\n",
"\n",
"Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n",
"\n",
"This tab has a line plot icon, something like 📈.\n",
"\n",
"It's also the default page you land on when looking at a W&B run.\n",
"\n",
"Charts are generated for everything we `.log` from PyTorch Lightning. The charts here are interactive and editable, and changes persist.\n",
"\n",
"Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n",
"\n",
"We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n",
"This tab has computer chip icon.\n",
"\n",
"It contains\n",
"- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n",
"- CPU metrics: memory usage, utilization, thread counts\n",
"- Disk and network I/O levels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n",
"This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n",
"\n",
"The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n",
"This tab has an icon that looks like a stylized command prompt, `>_`.\n",
"\n",
"It contains information that was printed to the stdout.\n",
"\n",
"This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n",
"\n",
"Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelSummary\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n",
"\n",
"For more on Lightning `Callback`s, see\n",
"[Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n",
"This tab has a stylized document icon, something like 📄.\n",
"\n",
"You can use this tab to view any files saved with the `wandb.save`.\n",
"\n",
"For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n",
"which we'll discuss shortly.\n",
"\n",
"But a few pieces of information automatically collected by W&B end up in this tab.\n",
"\n",
"Some highlights:\n",
" - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n",
" - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n",
"This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n",
"\n",
"This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n",
"\n",
"We store two kinds of binary files\n",
" - `run_table`s of model inputs and outputs\n",
" - `model` checkpoints\n",
"\n",
"We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interactive Tables of Logged Media"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returning to the Charts tab,\n",
"notice that we have model inputs and outputs logged in structured tables\n",
"under the train, validation, and test sections.\n",
"\n",
"These tables are interactive as well\n",
"([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n",
"They support basic exploratory data analysis and are compatible with W&B's collaboration features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to charts in our run page, these tables also have their own pages inside the W&B web app."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n",
"table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n",
"\n",
"print(table_data_url)\n",
"IFrame(src=table_data_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting this to work requires more effort and more W&B-specific code\n",
"than the other features we've seen so far.\n",
"\n",
"We'll briefly explain the implementation here, for those who are interested.\n",
"\n",
"We use a custom Lightning `Callback`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n",
"\n",
"\n",
"ImageToTextTableLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n",
"\n",
"The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_batches`.\n",
"\n",
"This behavior is sensible for metrics, which are low overhead, but not so much for media,\n",
"where we'd rather subsample and avoid holding on to too much information.\n",
"\n",
"So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n",
"\n",
"The frequency of media logging is then controlled with `--log_every_n_batches`, as with aggregate metric reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.lit_models.base import BaseImageToTextLitModel\n",
"\n",
"BaseImageToTextLitModel.add_on_logged_batches??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Projects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Everything we've seen so far has been related to a single run or experiment.\n",
"\n",
"Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n",
"\n",
"We organize our runs into \"projects\" and view them on the W&B \"project page\" \n",
"([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n",
"\n",
"By default in the Lightning integration, the project name is determined based on directory information.\n",
"This default can be over-ridden in the code when creating a `WandbLogger`,\n",
"but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what the project page looks like for a longer-running project with lots of experiments.\n",
"\n",
"The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n",
"\n",
"print(project_url)\n",
"IFrame(src=project_url, width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n",
"\n",
"We welcome you to poke around in this interface: deactivate or change the filters, clicking through into individual runs, and change the charts around."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Artifacts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond logging metrics and metadata from runs,\n",
"we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below pulls up all of the artifacts associated with the experiment we just ran."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Click on one of the `model` checkpoints -- the specific version doesn't matter.\n",
"\n",
"There are a number of tabs here.\n",
"\n",
"The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n",
"\n",
"The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n",
"which are added by default by the `WandbLogger`.\n",
"\n",
"The \"Files\" tab contains the actual file contents of the artifact.\n",
"\n",
"On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n",
"including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n",
"\n",
"You can click on these to explore the different versions and even directly compare them.\n",
"\n",
"If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Artifact storage is part of the W&B free tier.\n",
"\n",
"The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n",
"\n",
"The former is sufficient to store ~700 model checkpoints for the Text Recognizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can track your data storage and compare it to your limits at this URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n",
"\n",
"print(storage_tracker_url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Programmatic Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also programmatically access our data and metadata via the `wandb` API\n",
"([docs](https://docs.wandb.ai/guides/track/public-api-guide)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wb_api = wandb.Api()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run = wb_api.run(\"/\".join( # fetch a run given\n",
" [last_expt.entity, # the user or org it was logged to\n",
" last_expt.project, # the \"project\", usually one of several per repo/application\n",
" last_expt.id] # and a unique ID\n",
"))\n",
"\n",
"hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n",
"\n",
"hist.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hist.groupby(\"epoch\")[\"train/loss\"].mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this includes the artifacts:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# which artifacts where created and logged?\n",
"artifacts = run.logged_artifacts()\n",
"\n",
"for artifact in artifacts:\n",
" print(f\"artifact of type {artifact.type}: {artifact.name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thanks to our `ImageToTextTableLogger`,\n",
"we can easily recreate training or validation data that came out of our `DataLoader`s,\n",
"which is normally ephemeral:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n",
"artifact_dir = Path(artifact.download(root=\"training/logs\"))\n",
"image_dir = artifact_dir / \"media\" / \"images\"\n",
"\n",
"images = [path for path in image_dir.iterdir()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"from IPython.display import Image\n",
"\n",
"Image(str(random.choice(images)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Advanced W&B API Usage: MLOps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the strengths of a well-instrumented experiment tracking system is that it allows\n",
"automatic relation of information:\n",
"what were the inputs when this model's gradient spiked?\n",
"Which models have been trained on this dataset,\n",
"and what was their performance?\n",
"\n",
"Having access and automation around this information is necessary for \"MLOps\",\n",
"which applies contemporary DevOps principles to ML projects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below pull down the training data\n",
"for the model currently running the FSDL Text Recognizer app.\n",
"\n",
"This is just intended as a demonstration of what's possible,\n",
"so don't worry about understanding every piece of this,\n",
"and feel free to skip past it.\n",
"\n",
"MLOps is still a nascent field, and these tools and workflows are likely to change.\n",
"\n",
"For example, just before the course launched, W&B released a\n",
"[Model Registry layer](https://docs.wandb.ai/guides/models)\n",
"on top of artifact logging that aims to improve the developer experience for these workflows."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start from the same project we looked at in the project view:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n",
"\n",
"text_recognizer_project "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we search it for the text recognizer model currently being used in production:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# collect all versions of the text-recognizer ever put into production by...\n",
"\n",
"for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n",
" if art_type.name == \"prod-ready\": # for the prod-ready type\n",
" # and grabbing the text-recognizer\n",
" production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n",
"\n",
"# and then get the one that's currently being tested in CI by...\n",
"for text_recognizer in production_text_recognizers:\n",
" if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n",
" in_prod_text_recognizer = text_recognizer\n",
"\n",
"# view its metadata at the url or in the notebook\n",
"in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n",
"\n",
"print(in_prod_text_recognizer_url)\n",
"IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From its metadata, we can get information about how it was \"staged\" to be put into production,\n",
"and in particular which model checkpoint was used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staging_run = in_prod_text_recognizer.logged_by()\n",
"\n",
"training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n",
"training_ckpt.name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That checkpoint was logged by a training experiment, which is available as metadata.\n",
"\n",
"We can look at the training run for that model, either here in the notebook or at its URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"training_run = training_ckpt.logged_by()\n",
"print(training_run.url)\n",
"IFrame(src=training_run.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And from there, we can access logs and metadata about training,\n",
"confident that we are working with the model that is actually in production.\n",
"\n",
"For example, we can pull down the data we logged and analyze it locally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_results = training_run.history(samples=10000)\n",
"training_results.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n",
"training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 10\n",
"training_results[\"validation/loss\"].dropna().iloc[10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The charts and webpages in Weights & Biases\n",
"are substantially more useful than ephemeral stdouts or raw logs on disk.\n",
"\n",
"If you're spun up on the project,\n",
"they accelerate debugging, exploration, and discovery.\n",
"\n",
"If not, they're not so much useful as they are overwhelming.\n",
"\n",
"We need to synthesize the raw logged data into information.\n",
"This helps us communicate our work with other stakeholders,\n",
"preserve knowledge and prevent repetition of work,\n",
"and surface insights faster.\n",
"\n",
"These workflows are supported by the W&B Reports feature\n",
"([docs here](https://docs.wandb.ai/guides/reports)),\n",
"which mix W&B charts and tables with explanatory markdown text and embeds.\n",
"\n",
"Below are some common report patterns and\n",
"use cases and examples of each."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of the examples are from the FSDL Text Recognizer project.\n",
"You can find more of them\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n",
"where we've organized them into a report!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dashboard Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dashboards are a structured subset of the output from one or more experiments,\n",
"designed for quickly surfacing issues or insights,\n",
"like an accuracy or performance regression\n",
"or a change in the data distribution.\n",
"\n",
"Use cases:\n",
"- show the basic state of ongoing experiment\n",
"- compare one experiment to another\n",
"- select the most important charts so you can spin back up into context on a project more quickly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n",
"\n",
"IFrame(src=dashboard_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pull Request Documentation Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In most software codebases,\n",
"pull requests are a key focal point\n",
"for units of work that combine\n",
"short-term communication and long-term information tracking.\n",
"\n",
"In ML codebases, it's more difficult to bring\n",
"sufficient information together to make PRs as useful.\n",
"At FSDL, we like to add documentary\n",
"reports with one or a small number of charts\n",
"that connect logged information in the experiment management system\n",
"to state in the version control software.\n",
"\n",
"Use cases:\n",
"- communication of results within a team, e.g. code review\n",
"- record-keeping that links pull request pages to raw logged info and makes it discoverable\n",
"- improving confidence in PR correctness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n",
"\n",
"IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Blog Post Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With sufficient effort, the logged data in the experiment management system\n",
"can be made clear enough to be consumed,\n",
"sufficiently contextualized to be useful outside the team, and\n",
"even beautiful.\n",
"\n",
"The result is a report that's closer to a blog post than a dashboard or internal document.\n",
"\n",
"Use cases:\n",
"- communication between teams or vertically in large organizations\n",
"- external technical communication for branding and recruiting\n",
"- attracting users or contributors\n",
"\n",
"Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n",
"[Boris Dayma](https://twitter.com/borisdayma)\n",
"and others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n",
"\n",
"IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameter Optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many of our choices, like the depth of our network, the nonlinearities of our layers,\n",
"and the learning rate and other parameters of our optimizer, cannot be\n",
"([easily](https://arxiv.org/abs/1606.04474))\n",
"chosen by descent of the gradient of a loss function.\n",
"\n",
"But these parameters that impact the values of the parameters\n",
"we directly optimize with gradients, or _hyperparameters_,\n",
"can still be optimized,\n",
"essentially by trying options and selecting the values that worked best.\n",
"\n",
"In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n",
"\n",
"Expending more compute can squeeze out small amounts of additional validation or test performance,\n",
"which make for impressive results on leaderboards but typically don't translate\n",
"into a better user experience.\n",
"\n",
"In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n",
"built into your other tooling.\n",
"\n",
"Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n",
"([docs](https://docs.wandb.ai/guides/sweeps)).\n",
"\n",
"It also supports a number of more advanced tools, like\n",
"[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n",
"for early termination of poorly-performing runs.\n",
"\n",
"We can use the same training script and we don't need to run an optimization server.\n",
"\n",
"We just need to write a configuration yaml file\n",
"([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n",
"like the one below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile training/simple-overfit-sweep.yaml\n",
"# first we specify what we're sweeping\n",
"# we specify a program to run\n",
"program: training/run_experiment.py\n",
"# we optionally specify how to run it, including setting default arguments\n",
"command: \n",
" - ${env}\n",
" - ${interpreter}\n",
" - ${program}\n",
" - \"--wandb\"\n",
" - \"--overfit_batches\"\n",
" - \"1\"\n",
" - \"--log_every_n_steps\"\n",
" - \"25\"\n",
" - \"--max_epochs\"\n",
" - \"100\"\n",
" - \"--limit_test_batches\"\n",
" - \"0\"\n",
" - ${args} # these arguments come from the sweep parameters below\n",
"\n",
"# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n",
"method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n",
"metric:\n",
" name: train/loss\n",
" goal: minimize\n",
"parameters: \n",
" # LineCNN hyperparameters\n",
" window_width:\n",
" values: [8, 16, 32, 64]\n",
" window_stride:\n",
" values: [4, 8, 16, 32]\n",
" # Transformer hyperparameters\n",
" tf_layers:\n",
" values: [1, 2, 4, 8]\n",
" # we can also fix some values, just like we set default arguments\n",
" gpus:\n",
" value: 1\n",
" model_class:\n",
" value: LineCNNTransformer\n",
" data_class:\n",
" value: IAMLines\n",
" loss:\n",
" value: transformer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the config, we launch a \"controller\":\n",
"a lightweight process that just decides what hyperparameters to try next\n",
"and coordinates the heavier-weight training.\n",
"\n",
"This lives on the W&B servers, so there are no headaches about opening ports for communication,\n",
"cleaning up when it's done, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n",
"simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we can launch an \"agent\" to follow the orders of the controller:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# interrupt twice to terminate this cell if it's running too long,\n",
"# it can be over 15 minutes with some hyperparameters\n",
"\n",
"!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n",
"\n",
"If `--count` is not provided, the agent will run indefinitely for random or Bayesian sweeps,\n",
"until the sweep is terminated, which can be done from the W&B interface."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agents make for a slick workflow for distributing sweeps across GPUs.\n",
"\n",
"We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n",
"which controls which GPUs are accessible by a process, to launch\n",
"parallel agents on separate GPUs on the same machine."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n",
"# open another terminal\n",
"CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n",
"# and so on\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We include optional exercises with the labs for learners who want to dive deeper on specific topics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟 Contribute to a hyperparameter search."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n",
"\n",
"There are ~10,000,000 potential hyperparameter combinations,\n",
"and each takes 30 minutes to test,\n",
"so checking each possibility will take over 500 years of compute time.\n",
"Best get cracking then!\n",
"\n",
"Run the cell below to pull up a dashboard and print the URL where you can check on the current status."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_entity = \"fullstackdeeplearning\"\n",
"sweep_project = \"fsdl-line-recognizer-2022\"\n",
"sweep_id = \"e0eo43eu\"\n",
"sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n",
"\n",
"print(sweep_url)\n",
"IFrame(src=sweep_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also retrieve information about the sweep from the API,\n",
"including the hyperparameters being swept over."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hyperparams = sweep_info.config[\"parameters\"]\n",
"hyperparams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to contribute to this sweep,\n",
"run the cell below after changing the count to a number greater than 0.\n",
"\n",
"Each iteration runs for 30 minutes if it does not crash,\n",
"e.g. due to out-of-memory errors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"count = 0 # off by default, increase it to join in!\n",
"\n",
"if count:\n",
" !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Write some manual logging in `wandb`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the FSDL Text Recognizer codebase,\n",
"we almost exclusively log to W&B through Lightning,\n",
"rather than through the `wandb` Python SDK.\n",
"\n",
"If you're interested in learning how to use W&B directly, e.g. with another training framework,\n",
"try out this quick exercise that introduces the key players in the SDK."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n",
"\n",
"It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n",
"\n",
"Add W&B metric and artifact logging to this cell:\n",
"- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n",
"- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"import random\n",
"\n",
"import wandb\n",
"\n",
"\n",
"os.makedirs(\"logs\", exist_ok=True)\n",
"\n",
"project = \"trying-wandb\"\n",
"config = {\"steps\": 50}\n",
"\n",
"\n",
"with wandb.init(project=project, config=config) as run:\n",
" steps = wandb.config[\"steps\"]\n",
" \n",
" for ii in range(steps):\n",
" loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n",
" \n",
" with open(\"logs/hello.txt\", \"w\") as f:\n",
" f.write(\"hello from wandb, my dudes!\")\n",
" \n",
" run_id = run.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hello_run = wb_api.run(f\"{project}/{run_id}\")\n",
"\n",
"# check for logged loss data\n",
"if \"loss\" not in hello_run.history().keys():\n",
" print(\"loss not logged 🥲\")\n",
"else:\n",
" print(\"loss logged successfully 🥞\")\n",
" if len(hello_run.history()[\"loss\"]) != steps:\n",
" print(\"loss not logged on all steps 🥲\")\n",
" else:\n",
" print(\"loss logged on all steps 🥞\")\n",
"\n",
"artifacts = hello_run.logged_artifacts()\n",
"\n",
"# check for artifact with the right name\n",
"if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n",
" print(\"hello artifact not logged 🥲\")\n",
"else:\n",
" print(\"hello artifact logged successfully 🥞\")\n",
" # check for the file inside the artifacts\n",
" if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n",
" print(\"could not find hello.txt 🥲\")\n",
" else:\n",
" print(\"hello.txt logged successfully 🥞\")\n",
" \n",
" \n",
"hello_run"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try to find some better hyperparameters: choices that achieve a lower loss on the full dataset, faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you observe interesting phenomena during training,\n",
"from promising hyperparameter combos to software bugs to strange model behavior,\n",
"turn the charts into a W&B report and share it with the FSDL community or\n",
"[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n",
"with a link to them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# check the sweep_info.config above to see the model and data hyperparameters\n",
"# read through the --help output for all potential arguments\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n",
" --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n",
" --help # remove this line to run an experiment instead of printing help\n",
" \n",
"last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟🌟🌟 Add logging of tensor statistics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to logging model inputs and outputs as human-interpretable media,\n",
"it's also frequently useful to see information about their numerical values."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're interested in learning more about metric calculation and logging with Lightning,\n",
"use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n",
"to add tensor statistic logging to the `LineCNNTransformer`.\n",
"\n",
"`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n",
"\n",
"All three are useful, but start by adding just one."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n",
"- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n",
" - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n",
"- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n",
" - Base your code on the calculation and logging of the `val_cer` metric.\n",
" - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ",
"collapsed_sections": [],
"name": "lab04_experiments.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab05/notebooks/lab05_troubleshooting.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 05: Troubleshooting & Testing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`s\n",
"- How to implement tests for ML training systems in particular\n",
"- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 5\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sThWeTtV6fL_"
},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xFP8lU4nSg1P"
},
"source": [
"# Linting Python and Shell Scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cXbdYfFlPhZ-"
},
"source": [
"### Automatically linting with `pre-commit`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ysqqb2GjvLrz"
},
"source": [
"We want to keep our code clean and uniform across developers\n",
"and time.\n",
"\n",
"Applying the cleanliness checks and style rules should be\n",
"as painless and automatic as possible.\n",
"\n",
"For this purpose, we recommend bundling linting tools together\n",
"and enforcing them on all commits with\n",
"[`pre-commit`](https://pre-commit.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XvqtZChKvLr0"
},
"source": [
"In addition to running on every commit,\n",
"`pre-commit` separates the model development environment from the environments\n",
"needed for the linting tools, preventing conflicts\n",
"and simplifying maintenance and onboarding."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y0XuIuKOXhJl"
},
"source": [
"This cell runs `pre-commit`.\n",
"\n",
"The first time it is run on a machine, it will install the environments for all tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hltYGbpNvLr1"
},
"outputs": [],
"source": [
"!pre-commit run --all-files"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gLw08gIkvLr1"
},
"source": [
"The output lists all the checks that are run and whether they are passed.\n",
"\n",
"Notice there are a number of simple version-control hygiene practices included\n",
"that aren't even specific to Python, much less to machine learning.\n",
"\n",
"For example, several of the checks prevent accidental commits with private keys, large files, \n",
"leftover debugger statements, or merge conflict annotations in them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RHEEjb9kvLr1"
},
"source": [
"These linting actions are configured via\n",
"([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n",
"a YAML file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dgXa8BzrvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8HYc_WbTvLr2"
},
"source": [
"Most of the general cleanliness checks are from hooks built by `pre-commit`.\n",
"\n",
"See the comments and links in the `.pre-commit-config.yaml` for more:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K9rTgRqzvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep repos -A 15"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ptkO7aPvLr2"
},
"source": [
"Let's take a look at the section of the file\n",
"that applies most of our Python style enforcement with\n",
"[`flake8`](https://flake8.pycqa.org/en/latest/):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ALsRKfcevLr3",
"scrolled": true
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_Q0BwQUXbg6"
},
"source": [
"The majority of the style checking behavior we want comes from the\n",
"`additional_dependencies`, which are\n",
"[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n",
"that extend `flake8`'s list of lints.\n",
"\n",
"Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n",
"\n",
"We keep the configuration information for `flake8`\n",
"separate from that for `pre-commit`\n",
"in case we want to use additional tools with `flake8`,\n",
"e.g. if some developers want to integrate it directly into their editor,\n",
"and so that if we move away from `pre-commit`\n",
"but keep `flake8` we don't have to\n",
"recreate our configuration in a different tool.\n",
"\n",
"As much as possible, codebases should strive for single sources of truth\n",
"and link back to those sources of truth with documentation or comments,\n",
"as in the last line above.\n",
"\n",
"Let's take a look at the contents of the `.flake8` config file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "doC_4WQwvLr3"
},
"outputs": [],
"source": [
"!cat .flake8"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Nq6HnyU0M47"
},
"source": [
"There's a lot here! We'll focus on the most important bits."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U4PiB8CPvLr3"
},
"source": [
"Linting tools in Python generally work by emitting error codes\n",
"with one or more letters followed by three numbers.\n",
"The `select` argument picks which error codes we want to check for.\n",
"Error codes are matched by prefix,\n",
"so for example `B` matches `BTS101` and\n",
"`G1` matches `G102` and `G199` but not `ARG404`.\n",
"\n",
"Certain codes are `ignore`d in the default `flake8` style,\n",
"which is done via the `ignore` argument,\n",
"and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n",
"For example, we rely on `black` to do our formatting,\n",
"so we ignore some of `flake8`'s formatting codes.\n",
"\n",
"Together, these settings define our project's particular style.\n",
"\n",
"But not every file fits this style perfectly.\n",
"Most of the conventions in `black` and `flake8` come from the style-defining\n",
"[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n",
"which exhorts you to \"know when to be inconsistent\".\n",
"\n",
"To allow ourselves to be inconsistent when we know we should be,\n",
"`flake8` includes `per-file-ignores`,\n",
"which let us ignore specific warnings in specific files.\n",
"This is one of the \"escape valves\"\n",
"that makes style enforcement tolerable.\n",
"We can also `exclude` files in the `pre-commit` config itself.\n",
"\n",
"For details on selecting and ignoring,\n",
"see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html).\n",
"\n",
"For definitions of the error codes from `flake8` itself,\n",
"see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n",
"Individual extensions list their added error codes in their documentation,\n",
"e.g. `darglint` does so\n",
"[here](https://github.com/terrencepreilly/darglint#error-codes)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NL0TpyPsvLr4"
},
"source": [
"The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n",
"\n",
"You can read more about each in their documentation:\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mFsZC0a7vLr4"
},
"source": [
"### Linting via a script and using `shellcheck`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RYjpuFwjXkJc"
},
"source": [
"To avoid needing to think about `pre-commit`\n",
"(was the command `pre-commit run` or `pre-commit check`?)\n",
"while developing locally,\n",
"we might put our linters into a shell script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mXlLFWmavLr4"
},
"outputs": [],
"source": [
"!cat tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PPxHpRIB3nbw"
},
"source": [
"These kinds of short and simple shell scripts are common in projects\n",
"of intermediate size.\n",
"\n",
"They are useful for adding automation and reducing friction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TMuPBpAi2qwl"
},
"source": [
"But these scripts are code,\n",
"and all code is susceptible to bugs and subject to concerns of style consistency."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SQRg3ZqXvLr4"
},
"source": [
"We can't check these scripts with tools that lint Python code,\n",
"so we include a shell script linting tool,\n",
"[`shellcheck`](https://www.shellcheck.net/),\n",
"in our `pre-commit`.\n",
"\n",
"More so than checking for correct style,\n",
"this tool checks for common bugs or surprising behaviors of shells,\n",
"which are unfortunately numerous."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zkfhE1srvLr4"
},
"outputs": [],
"source": [
"script_filename = \"tasks/lint.sh\"\n",
"!pre-commit run shellcheck --files {script_filename}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXU9TRrwvLr4"
},
"source": [
"That script has already been tested, so we don't see any errors.\n",
"\n",
"Try copying over a script you've written yourself or\n",
"even from a popular repo that you like\n",
"(by adding to the notebook directory or by making a cell\n",
"with `%%writefile` at the top)\n",
"and test it by changing the `script_filename`.\n",
"\n",
"You'd be surprised at the classes of subtle bugs possible in bash!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "81MhAL-TvLr5"
},
"source": [
"### Try \"unofficial bash strict mode\" for louder failures in scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hSwhs_zUvLr5"
},
"source": [
"Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n",
"[@redsymbol](https://twitter.com/redsymbol),\n",
"which appear at the top of the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o-j0vSxEvLr5"
},
"outputs": [],
"source": [
"!head -n 3 tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d2iJU5jlvLr5"
},
"source": [
"The core idea of strict mode is to fail more loudly.\n",
"This is a desirable behavior of scripts,\n",
"like the ones we're writing,\n",
"even though it's an undesirable behavior for an interactive shell --\n",
"it would be unpleasant to be logged out every time you hit an error.\n",
"\n",
"`set -u` means scripts fail if a variable's value is `u`nset,\n",
"i.e. not defined.\n",
"Otherwise bash is perfectly happy to allow you to reference undefined variables.\n",
"The result is just an empty string, which can lead to maddeningly weird behavior.\n",
"\n",
"`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n",
"rather than using the exit code of the last command.\n",
"Unix tools are perfectly happy to work on nonsense input,\n",
"like sorting error messages, instead of the filenames you meant to send.\n",
"\n",
"You can read more about these choices\n",
"[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n",
"along with considerations for working with other non-conforming scripts in \"strict mode\"\n",
"and for handling resource teardown when scripts error out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s1XqsrU_XWWS"
},
"source": [
"# Testing ML Codebases"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPNzeq3NYF2W"
},
"source": [
"## Testing Python code with `pytest`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zq5e_x6gc9Vu"
},
"source": [
"\n",
"ML codebases are Python first and foremost, so first let's get some Python tests going."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DC3GxYz6_R9"
},
"source": [
"At a basic level,\n",
"we can write functions that `assert`\n",
"that our code behaves as expected in\n",
"a given scenario and include them in the same module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rvd-GNwv63W1"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.metrics import test_character_error_rate\n",
"\n",
"test_character_error_rate??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iVB2TsQS5BTq"
},
"source": [
"The standard tool for testing Python code is\n",
"[`pytest`](https://docs.pytest.org/en/7.1.x/).\n",
"\n",
"We can use it as a command-line tool in a variety of ways,\n",
"including to execute these kinds of tests.\n",
"\n",
"If passed a filename, `pytest` will look for\n",
"any classes that start with `Test` or\n",
"any functions that start with `test_` and run them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u8sQguyJvLr6",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest text_recognizer/lit_models/metrics.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "92tkBCllvLr6"
},
"source": [
"After the results of the tests (pass or fail) are returned,\n",
"you'll see a report of \"coverage\" from\n",
"[`codecov`](https://about.codecov.io/).\n",
"\n",
"This coverage report tells us which files and how many lines in those files\n",
"were touched by the testing suite."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PllSUe0s5xvU"
},
"source": [
"We do not actually need to provide the names of files with tests in them to `pytest`\n",
"in order for it to run our tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4qOBHJnTZM9x"
},
"source": [
"By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n",
"\n",
"It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n",
"to separate these from the rest of your code\n",
"in a folder or folders named `tests`,\n",
"rather than scattering them around the repo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "acjsYTNSvLr6"
},
"outputs": [],
"source": [
"!ls text_recognizer/tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WZQQZUF0vLr6"
},
"source": [
"Let's take a look at a specific example:\n",
"the tests for some of our utilities around\n",
"custom PyTorch Lightning `Callback`s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oS0xKv1evLr6"
},
"outputs": [],
"source": [
"from text_recognizer.tests import test_callback_utils\n",
"\n",
"\n",
"test_callback_utils.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lko8msn-vLr7"
},
"source": [
"Notice that we can easily import this as a module!\n",
"\n",
"That's another benefit of organizing tests into specialized files."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5A85FUNv75Fr"
},
"source": [
"The particular utility we're testing\n",
"here is designed to prevent crashes:\n",
"it checks for a particular type of error and turns it into a warning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jl4-DiVe76sw"
},
"outputs": [],
"source": [
"from text_recognizer.callbacks.util import check_and_warn\n",
"\n",
"check_and_warn??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B6E0MhduvLr7"
},
"source": [
"Error-handling code is a common cause of bugs,\n",
"a fact discovered\n",
"[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n",
"so it's very important to test it well!\n",
"\n",
"We start with a very basic test,\n",
"which does not touch anything\n",
"outside of the Python standard library,\n",
"even though this tool is intended to be used\n",
"with more complex features of third-party libraries,\n",
"like `wandb` and `tensorboard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xx5koQmJvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_simple??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MZe9-JVjvLr7"
},
"source": [
"Here, we are just testing the core logic.\n",
"This test won't catch many bugs,\n",
"but when it does fail, something has gone seriously wrong.\n",
"\n",
"These kinds of tests are important for resolving a bug:\n",
"we learn nearly as much from the tests that passed\n",
"as we did from the tests that failed.\n",
"If this test has failed, possibly along with others,\n",
"we can rule out an issue in one of the large external codebases\n",
"touched in the other tests, saving us lots of time in our troubleshooting.\n",
"\n",
"The reasoning for the test is explained in the docstrings, \n",
"which are close to the code.\n",
"\n",
"Your test suite should be as welcoming\n",
"as the rest of your codebase!\n",
"The people reading it, for example yourself in six months, \n",
"are likely upset and in need of some kindness.\n",
"\n",
"More practically, we want keep our time to resolve errors as short as possible,\n",
"and five minutes to write a good docstring now\n",
"can save five minutes during an outage, when minutes really matter."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Om9k-uXhvLr7"
},
"source": [
"That basic test is a start, but it's not enough by itself.\n",
"There's a specific error case that triggered the addition of this code.\n",
"\n",
"So we test that it's handled as expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fjbsb5FvvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_tblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGAIZTUjvLr7"
},
"source": [
"That test can fail if the libraries change around our code,\n",
"i.e. if the `TensorBoardLogger` gets a `log_table` method.\n",
"\n",
"We want to be careful when making assumptions\n",
"about other people's software,\n",
"especially for fast-moving libraries like Lightning.\n",
"If we test that those assumptions hold willy-nilly,\n",
"we'll end up with tests that fail because of\n",
"harmless changes in our dependencies.\n",
"\n",
"Tests that require a ton of maintenance and updating\n",
"without leading to code improvements soak up\n",
"more engineering time than they save\n",
"and cause distrust in the testing suite.\n",
"\n",
"We include this test because `TensorBoardLogger` getting\n",
"a `log_table` method will _also_ change the behavior of our code\n",
"in a breaking way, and we want to catch that before it breaks\n",
"a model training job."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jsy95KAvvLr7"
},
"source": [
"Adding error handling can also accidentally kill the \"happy path\"\n",
"by raising an error incorrectly.\n",
"\n",
"So we explicitly test the _absence of an error_,\n",
"not just its presence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LRlIOkjmvLr8"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_wandblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osiqpLynvLr8"
},
"source": [
"There are more tests we could build, e.g. manipulating classes and testing the behavior,\n",
"testing more classes that might be targeted by `check_and_warn`, or\n",
"asserting that warnings are raised to the command line.\n",
"\n",
"But these three basic tests are likely to catch most changes that would break our code here,\n",
"and they're a lot easier to write than the others.\n",
"\n",
"If this utility starts to get more usage and become a critical path for lots of features, we can always add more!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dm285JE5vLr8"
},
"source": [
"## Interleaving testing and documentation with `doctests`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHWQvgA8vLr8"
},
"source": [
"One function of tests is to build user/reader confidence in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wrhiJBXFvLr8"
},
"source": [
"One function of documentation is to build user/reader knowledge in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1vu12LDhvLr8"
},
"source": [
"These functions are related. Let's put them together:\n",
"put code in a docstring and test that code.\n",
"\n",
"This feature is part of the\n",
"Python standard library via the\n",
"[`doctest` module](https://docs.python.org/3/library/doctest.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rmfIOwXd-Qt7"
},
"source": [
"Here's an example from our `torch` utilities.\n",
"\n",
"The `first_appearance` function can be used to\n",
"e.g. quickly look for stop tokens,\n",
"giving the length of each sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZzURGcD9vLr8"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.util import first_appearance\n",
"\n",
"\n",
"first_appearance??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0VtYcJ1WvLr8"
},
"source": [
"Notice that in the \"Examples\" section,\n",
"there's a short block of code formatted as a\n",
"Python interpreter session,\n",
"complete with outputs.\n",
"\n",
"We can copy and paste that code and\n",
"check that we get the right outputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dj4lNOxJvLr9"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y9AWHFoIvLr9"
},
"source": [
"We can run the test with `pytest` by passing a command line argument,\n",
"`--doctest-modules`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JMaAxv5ovLr9"
},
"outputs": [],
"source": [
"!pytest --doctest-modules text_recognizer/lit_models/util.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6-2_aOUfvLr9"
},
"source": [
"With the\n",
"[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n",
"running `doctest`s happens automatically\n",
"when `pytest` is invoked."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "my_keokPvLr9"
},
"source": [
"## Basic tests for data code"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qj3Bq_j2_A8o"
},
"source": [
"ML code can be hard to test\n",
"since it involes very heavy artifacts, like models and data,\n",
"and very expensive jobs, like training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DT5OmgrQvLr9"
},
"source": [
"For testing our data-handling code in the FSDL codebase,\n",
"we mostly just use `assert`s,\n",
"which throw errors when behavior differs from expectation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bdzn5g4TvLr9"
},
"outputs": [],
"source": [
"!grep \"assert\" -r text_recognizer/data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2aTlfu4_vLr-"
},
"source": [
"This isn't great practice,\n",
"especially as a codebase grows,\n",
"because we can't easily know when these are executed\n",
"or incorporate them into\n",
"testing automation and coverage analysis tools."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IaMTdmbZ_mkW"
},
"source": [
"So it's preferable to collect up these assertions of simple data properties\n",
"into tests that are run like our other tests.\n",
"\n",
"The test below checks whether any data is leaking\n",
"between training, validation, and testing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qx7cxiDdvLr-"
},
"outputs": [],
"source": [
"from text_recognizer.tests.test_iam import test_iam_data_splits\n",
"\n",
"\n",
"test_iam_data_splits??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16TJwhd1vLr-"
},
"source": [
"Notice that we were able to load the test into the notebook\n",
"because it is in a module,\n",
"and so we can run it here as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mArITFkYvLr-"
},
"outputs": [],
"source": [
"test_iam_data_splits()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E4F2uaclvLr-"
},
"source": [
"But we're checking something pretty simple here,\n",
"so the new code in each test is just a single line.\n",
"\n",
"What if we wanted to test more complex properties,\n",
"like comparing rows or calculating statistics?\n",
"\n",
"We'll end up writing more complex code that might itself have subtle bugs,\n",
"requiring tests for our tests and suffering from\n",
"\"tester's regress\".\n",
"\n",
"This is the phenomenon,\n",
"named by analogy with\n",
"[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n",
"in sociology of science,\n",
"where the validity of our tests is itself\n",
"up for dispute only resolvable by testing the tests,\n",
"but those tests are themselves possibly invalid."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nUGT06gdvLr-"
},
"source": [
"We cut this Gordian knot by using\n",
"a library or framework that is well-tested.\n",
"\n",
"We recommend checking out\n",
"[`great_expectations`](https://docs.greatexpectations.io/docs/)\n",
"if you're looking for a high-quality data testing tool."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dQ5vNsq3vLr-"
},
"source": [
"Especially with data, some tests are particularly \"heavy\" --\n",
"they take a long time,\n",
"and we might want to run them\n",
"on different machines\n",
"and on a different schedule\n",
"than our other tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xephcb0LvLr-"
},
"source": [
"For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n",
"\n",
"We can't just use a cached version of the data,\n",
"since that won't actually execute the code!\n",
"\n",
"This test will take\n",
"as long to run\n",
"and consume as many resources as\n",
"a full download of the data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YSN4w2EqvLr-"
},
"source": [
"`pytest` allows the separation of tests\n",
"into suites with `mark`s,\n",
"which \"tag\" tests with names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V0rScrcXvLr_",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest --markers | head -n 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lr5Ca7B0vLr_"
},
"source": [
"We can choose to run tests with a given mark\n",
"or to skip tests with a given mark, \n",
"among other basic logical operations around combining and filtering marks,\n",
"with `-m`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xmw-Eb1ZvLr_"
},
"outputs": [],
"source": [
"!wandb login # one test requires wandb authentication\n",
"\n",
"!pytest -m \"not data and not slow\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5LuERxOXX_UJ"
},
"source": [
"## Testing training with memorization tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AnWLN4lRvLsA"
},
"source": [
"Training is the process by which we convert inert data into executable models,\n",
"so it is dependent on both.\n",
"\n",
"We decouple checking whether the script has a critical bug\n",
"from whether the data or model code is broken\n",
"by testing on some basic \"fake data\",\n",
"based on a utility from `torchvision`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k4NIc3uWvLsA"
},
"outputs": [],
"source": [
"from text_recognizer.data import FakeImageData\n",
"\n",
"\n",
"FakeImageData.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deN0swwlvLsA"
},
"source": [
"We then test on the actual data with a smaller version of the real model.\n",
"\n",
"We use the Lightning `--fast_dev_run` feature,\n",
"which sets the number of training, validation, and test batches to `1`.\n",
"\n",
"We use a smaller version so that this test can run in just a few minutes\n",
"on a CPU without acceleration.\n",
"\n",
"That allows us to run our tests in environments without GPUs,\n",
"which saves on costs for executing tests.\n",
"\n",
"Here's the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z4J0_uD9vLsA"
},
"outputs": [],
"source": [
"!cat training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-7u9zS1vLsA",
"scrolled": false
},
"outputs": [],
"source": [
"! ./training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UTzfo11KClV3"
},
"source": [
"The above tests don't actaully check\n",
"whether any learning occurs,\n",
"they just check\n",
"whether training runs mechanically,\n",
"without any errors.\n",
"\n",
"We also need a\n",
"[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n",
"for learning.\n",
"For that we recommending checking whether\n",
"the model can learn the right\n",
"outputs for a single batch --\n",
"to \"memorize\" the outputs for\n",
"a particular input.\n",
"\n",
"This memorization test won't\n",
"catch every bug or issue in training,\n",
"which is notoriously difficult,\n",
"but it will flag\n",
"some of the most serious issues."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DVSp3aAvLsA"
},
"source": [
"The script below runs a memorization test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2DFVVrxpvLsA"
},
"source": [
"It takes up to two arguments:\n",
"a `MAX`imum number of `EPOCHS` to run for and\n",
"a `CRITERION` value of the loss to test against.\n",
"\n",
"The test passes if the loss is lower than the `CRITERION` value\n",
"after the `MAX`imum number of `EPOCHS` has passed."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oEhJH0e5vLsB"
},
"source": [
"The important line in this script is the one that invokes our training script,\n",
"`training/run_experiment.py`.\n",
"\n",
"The arguments to `run_experiment` have been tuned for maximum possible speed:\n",
"turning off regularization, shrinking the model,\n",
"and skipping parts of Lightning that we don't want to test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T-fFs1xEvLsB"
},
"outputs": [],
"source": [
"!cat training/tests/test_memorize_iam.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X-47tUA_YNGe"
},
"source": [
"If you'd like to see what a memorization run looks like,\n",
"flip the `running_memorization` flag to `True`\n",
"and watch the results stream in to W&B.\n",
"\n",
"The cell should run in about ten minutes on a commodity GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GwTEsZwKvLsB"
},
"outputs": [],
"source": [
"%%time\n",
"running_memorization = False\n",
"\n",
"if running_memorization:\n",
" max_epochs = 1000\n",
" loss_criterion = 0.05\n",
" !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Troubleshooting model speed with the PyTorch Profiler"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpbN-Om2Drf-"
},
"source": [
"Testing code is only half the story here:\n",
"we also need to fix the issues that our tests flag.\n",
"This is the process of troubleshooting.\n",
"\n",
"In this lab,\n",
"we'll focus on troubleshooting model performance issues:\n",
"what do to when your model runs too slowly."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NZzwELPXvLsD"
},
"source": [
"Troubleshooting deep neural networks for speed is challenging.\n",
"\n",
"There are at least three different common approaches,\n",
"each with an increasing level of skill required:\n",
"\n",
"1. Follow best practices advice from others\n",
"([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n",
"[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n",
"2. Take code that runs slowly and use empirical observations to iteratively improve it.\n",
"3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n",
"\n",
"For the full stack deep learning engineer,\n",
"the final level is typically out of reach,\n",
"unless you're specializing in the model performance\n",
"part of the stack in particular.\n",
"\n",
"So we recommend reaching the middle level,\n",
"and this segment of the lab walks through the\n",
"tools that make this easier."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3_yp87UrFZ8M"
},
"source": [
"Because neural network training involves GPU acceleration,\n",
"generic Python profiling tools like\n",
"[`py-spy`](https://github.com/benfred/py-spy)\n",
"won't work, and\n",
"we'll need tools specialized for tracing and profiling DNN training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yspsYVFGEyZm"
},
"source": [
"In general, these tools are for observing what happens while your code is executing:\n",
"_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n",
"\n",
"Because they help us observe the execution in detail,\n",
"they will also help us understand just what is going on during\n",
"a PyTorch training step in greater detail."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YqXq2hKuvLsE"
},
"source": [
"To support profiling and tracing,\n",
"we've added a new argument to `training/run_experiment.py`, `--profile`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z_GMMViWvLsE"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZldoksHPvLsE"
},
"source": [
"As with experiment management, this relies mostly on features of PyTorch Lightning,\n",
"which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n",
"and we just add a few lines of customization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F2iJ0_A6vLsE"
},
"outputs": [],
"source": [
"!cat training/run_experiment.py | grep args.profile -A 5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aw3ppgndvLsE"
},
"source": [
"For more on profiling with Lightning, see the\n",
"[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCAmNW3QEtcD"
},
"source": [
"The cell below runs an epoch of training with tracing and profiling turned on\n",
"and then saves the results locally and to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t4o3ylDgr46F",
"scrolled": false
},
"outputs": [],
"source": [
"import glob\n",
"\n",
"import torch\n",
"import wandb\n",
"\n",
"from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n",
"\n",
"\n",
"# make it easier to separate these from training runs\n",
"%env WANDB_JOB_TYPE=profile\n",
"\n",
"batch_size = 16\n",
"num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n",
"gpus = 1 # must be run with accelerator\n",
"\n",
"%run training/run_experiment.py --wandb --profile \\\n",
" --max_epochs=1 \\\n",
" --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n",
" --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n",
" --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus=1\n",
"\n",
"latest_expt = wandb.run\n",
"\n",
"try: # add execution trace to logged and versioned binaries\n",
" folder = wandb.run.dir\n",
" trace_matcher = wandb.run.dir + \"/*.pt.trace.json\"\n",
" trace_file = glob.glob(trace_matcher)[0]\n",
" trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n",
" trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n",
" wandb.log_artifact(trace_at)\n",
"except IndexError:\n",
" print(\"trace not found\")\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ePTkS3EqO5tN"
},
"source": [
"We get out a table of statistics in the terminal,\n",
"courtesy of Lightning.\n",
"\n",
"Each row lists an operation\n",
"and and provides information,\n",
"described in the column headers,\n",
"about the time spent on that operation\n",
"across all the training steps we profiled.\n",
"\n",
"With practice, some useful information can be read out from this table,\n",
"but it's better to start from both a less detailed view,\n",
"in the TensorBoard dashboard,\n",
"and a more detailed view,\n",
"using the Chrome Trace viewer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TzV62f3c7-Bi"
},
"source": [
"## High-level statistics from the PyTorch Profiler in TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mNPKXkYw8NWd"
},
"source": [
"Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CbItwuT88eAV"
},
"outputs": [],
"source": [
"your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n",
"\n",
"print(your_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jE_LooMYHFpF"
},
"source": [
"If at any point you run into issues,\n",
"like the description not matching what you observe,\n",
"check out one of our example runs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "za2zybSwIo5C"
},
"outputs": [],
"source": [
"example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n",
"print(example_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xlrhl1n4HYU6"
},
"source": [
"Once the TensorBoard session has loaded up,\n",
"we are dropped into the Overview\n",
"(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n",
"for an example).\n",
"\n",
"In the top center, we see the **GPU Summary** for our system.\n",
"\n",
"In addition to the name of our GPU,\n",
"there are a few configuration details and top-level statistics.\n",
"They are (tersely) documented\n",
"[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MmBhUDgDLhd1"
},
"source": [
"- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n",
"this is effectively a coarse \"version number\" for your GPU hardware.\n",
"It indexes which features are available,\n",
"with more advanced features being available only at higher compute capabilities.\n",
"It does not directly index the speed or memory of the GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "voUgT6zuLyi0"
},
"source": [
"- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the sytem metrics tab in W&B. This metric will be our first target to increase."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yl-IndtXE4b4"
},
"source": [
"- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n",
"for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n",
"Tensor Cores.\n",
"If you're running on an older GPU without Tensor Cores,\n",
"you should consider upgrading.\n",
"If you're running a more recent GPU but not seeing Tensor Core usage,\n",
"you should switch to single precision floating point numbers,\n",
"which Tensor Cores are specialized on."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XxcUf0bBNXy_"
},
"source": [
"- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n",
"at a lower level than just whether something is running at all,\n",
"as in utilization.\n",
"Unlike utilization, reaching 100% is not generally feasible\n",
"and sometimes not desirable.\n",
"Increasing these numbers requires expertise in\n",
"CUDA programming, so we'll target utilization instead."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A88pQn4YMMKc"
},
"source": [
"- **Execution Summary**: This table and pie chart indicates\n",
"how much time within a profiled step\n",
"was spent in each category.\n",
"The value for \"kernel\" execution here\n",
"is equal to the GPU utilization,\n",
"and we want that number to be as close to 100%\n",
"as possible.\n",
"This summary helps us know which\n",
"other operations are taking time,\n",
"like memory being copied between CPU and GPU (`memcpy`)\n",
"or `DataLoader`s executing on the CPU,\n",
"so we can decide where the bottleneck is."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6qjW1RlTQRPv"
},
"source": [
"At the very bottom, you'll find a\n",
"**Performance Recommendation**\n",
"tab that sometimes suggests specific methods for improving performance.\n",
"\n",
"If this tab makes suggestions, you should certainly take them!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pWY5AhrcRQmJ"
},
"source": [
"For more on using the profiler in TensorBoard,\n",
"including some of the other, more detailed views\n",
"available view the \"Views\" dropdown menu, see\n",
"[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQwrPY_H77H8"
},
"source": [
"## Going deeper with the Chrome Trace Viewer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yhwo7fslvLsH"
},
"source": [
"So far, we've seen summary-level information about our training steps\n",
"in the table from Lightning and in the TensorBoard Overview.\n",
"These give aggregate statistics about the computations that occurred,\n",
"but understanding how to interpret those statistics\n",
"and use them to speed up our networks\n",
"requires understanding just what is\n",
"happening in our training step.\n",
"\n",
"Fundamentally,\n",
"all computations are processes that unfold in time.\n",
"\n",
"If we want to really understand our training step,\n",
"we need to display it that way:\n",
"what operations were occurring,\n",
"on both the CPU and GPU,\n",
"at each moment in time during the training step.\n",
"\n",
"This information on timing is collected in the trace.\n",
"One of the best tools for viewing the trace over time\n",
"is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wUkZItxYc20A"
},
"source": [
"Let's tour the trace we just logged\n",
"with an aim to really understanding just\n",
"what is happening when we call\n",
"`training_step`\n",
"and by extension `.forward`, `.backward`, and `optimizer.step`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9w9F2UA7Qctg"
},
"source": [
"The Chrome Trace Viewer is built into W&B,\n",
"so we can view our traces in their interface.\n",
"\n",
"The cell below embeds the trace inside the notebook,\n",
"but you may wish to open it separately,\n",
"with the \"Open page\" button or by navigating to the URL,\n",
"so that you can interact with it\n",
"as you read the description below.\n",
"Display directly on W&B is also a bit less temperamental\n",
"than display on W&B inside a notebook.\n",
"\n",
"Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n",
"so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n",
"It also can interact poorly with browser extensions (e.g. ad blockers),\n",
"so you may need to deactivate them temporarily in order to see it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OMUs4aby6Rfd"
},
"outputs": [],
"source": [
"trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n",
"trace_url = trace_files_url + \"training_step.pt.trace.json\"\n",
"\n",
"example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n",
"\n",
"print(trace_url)\n",
"IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qNVpGeQtQjMG"
},
"source": [
"> **Heads up!** We're about to do a tour of the\n",
"> precise details of the tracing information logged\n",
"> during the execution of the training code.\n",
"> The only way to learn how to troubleshoot model performance\n",
"> empirically is to look at the details,\n",
"> but the details depend on the precise machine being used\n",
"> -- GPU and CPU and RAM.\n",
"> That means even within Colab,\n",
"> these details change from session to session.\n",
"> So if you don't observe a phenomenon or feature\n",
"> described in the tour below, check out\n",
"> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n",
"> on W&B while reading through the next section of the lab,\n",
"> and return to your trace once you understand the trace viewer better at the end.\n",
"> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n",
"> can sometimes be a bit janky."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kXMcBhnCgdN_"
},
"source": [
"This trace reveals, in nanosecond-level detail,\n",
"what's going on inside of a `training_step`\n",
"on both the GPU and the CPU.\n",
"\n",
"Time is on the horizontal axis.\n",
"Colored bars represent method calls,\n",
"and the methods called by a method are placed underneath it vertically,\n",
"a visualization known as an\n",
"[icicle chart](https://www.brendangregg.com/flamegraphs.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "67BsNzDfVIeg"
},
"source": [
"Let's orient ourselves with some gross features:\n",
"the forwards pass,\n",
"GPU kernel execution,\n",
"the backwards pass,\n",
"and the optimizer step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBEFgtRCKqrh"
},
"source": [
"### The forwards pass"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nYhiWesVMjK"
},
"source": [
"Type `resnet` into the search bar in the top-right.\n",
"\n",
"This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n",
"\n",
"It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n",
"\n",
"You can click the arrows next to that tile to partially collapse these blocks.\n",
"\n",
"Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n",
"It should be at roughly the same height.\n",
"\n",
"Clear the search bar so that the trace is in color.\n",
"Zoom in on the area of the forwards pass\n",
"using the \"zoom\" tool in the floating toolbar,\n",
"so you can see more detail.\n",
"The zoom tool is indicated by a two-headed arrow\n",
"pointing into and out of the screen.\n",
"\n",
"Switch to the \"drag\" tool,\n",
"represented by a four-headed arrow.\n",
"Click-and-hold to use this tool to focus\n",
"on different parts of the timeline\n",
"and click on the individual colored boxes\n",
"to see details about a particular method call.\n",
"\n",
"As we go down in the icicle chart,\n",
"we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n",
"to much more precise `cudnn` and `cuda` operations\n",
"(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n",
"\n",
"`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n",
"is the tensor math library in PyTorch\n",
"that links to specific backends like `cudnn`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fq181ybIvLsH"
},
"source": [
"### GPU kernel execution"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbkWp5aKvLsH"
},
"source": [
"Towards the bottom, you should see a section labeled \"GPU\".\n",
"The label appears on the far left.\n",
"\n",
"Within it, you'll see one or more \"`stream`s\".\n",
"These are units of work on a GPU,\n",
"akin loosely to threads on the CPU.\n",
"\n",
"When there are colored bars in this area,\n",
"the GPU is doing work of some kind.\n",
"The fraction of this bar that is filled in with color\n",
"is the same as the \"GPU Utilization %\" we've seen previously.\n",
"So the first thing to visually assess\n",
"in a trace view of PyTorch code\n",
"is what fraction of this area is filled with color.\n",
"\n",
"In CUDA, work is queued up to be\n",
"placed into streams and completed, on the GPU,\n",
"in a distributed and asynchronous manner.\n",
"\n",
"The selection of which work to do\n",
"is happening on the CPU,\n",
"and that's what we were looking at above.\n",
"\n",
"The CPU and the GPU have to work together to coordinate\n",
"this work.\n",
"\n",
"Type `cuda` into the search bar and you'll see these coordination operations happening:\n",
"`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n",
"\n",
"Running the same PyTorch model\n",
"with the same high level operations like `Conv2d` in different versions of PyTorch,\n",
"on different GPUs, and even on tensors of different sizes will result\n",
"in different choices of concrete kernel operation,\n",
"e.g. different matrix multiplication algorithms.\n",
"\n",
"Type `sync` into the search bar and you'll see places where either work on the GPU\n",
"or work on the CPU needs to await synchronization,\n",
"e.g. copying data from the CPU to the GPU\n",
"or the CPU waiting to decide what to do next\n",
"on the basis of the contents of a tensor.\n",
"\n",
"If you see a \"sync\" block above an area\n",
"where the stream on the GPU is empty,\n",
"you've got a performance bottleneck due to synchronization\n",
"between the CPU and GPU.\n",
"\n",
"To resolve the bottleneck,\n",
"head up the icicle chart until you reach the recognizable\n",
"PyTorch modules and operations.\n",
"Find where they are called in your PyTorch module.\n",
"That's a good place to review your code to understand why the synchronization is happening\n",
"and remove it if it's not necessary."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XeMPbu_jvLsI"
},
"source": [
"### The backwards pass\n",
"\n",
"Type in `backward` into the search bar.\n",
"\n",
"This will highlight components of our backwards pass.\n",
"\n",
"If you read it from left to right,\n",
"you'll see that it begins by calculating the loss\n",
"(type `NllLoss2DBackward` into the search bar if you can't find it)\n",
"and ends by doing a `ConvolutionBackward`,\n",
"the first layer of the ResNet.\n",
"It is, indeed, backwards.\n",
"\n",
"Like the forwards pass,\n",
"the backwards pass also involves the CPU\n",
"telling the GPU which kernels to run.\n",
"It's typically run in a separate\n",
"thread from the forwards pass,\n",
"so you'll see it separated out from the forwards pass\n",
"in the trace viewer.\n",
"\n",
"Generally, there's no need to specifically optimize the backwards pass --\n",
"removing bottlenecks in the forwards pass results in a fast backwards pass.\n",
"\n",
"One reason why is that these two passes are just\n",
"\"transposes\" of one another,\n",
"so they share a lot of properties,\n",
"and bottlenecks in one become bottlenecks in the other.\n",
"We can choose to optimize either one of the two.\n",
"But the forwards pass is under our direct control,\n",
"so it's easier for us to reason about.\n",
"\n",
"Another reason is that the forwards pass is more likely to have bottlenecks.\n",
"The forwards pass is a dynamic process,\n",
"with each line of Python adding more to the compute graph.\n",
"Backwards passes, on the other hand, use a static compute graph,\n",
"the one just defined by the forwards pass,\n",
"so more optimizations are possible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gWiDw0vCvLsI"
},
"source": [
"### The optimizer step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ndfkzEdnvLsI"
},
"source": [
"Type `Adam.step` into the search bar to highlight the computations of the optimizer.\n",
"\n",
"As with the two passes,\n",
"we are still using the CPU\n",
"to launch kernels on the GPU.\n",
"But now the CPU is looping,\n",
"in Python, over the parameters\n",
"and applying the Adam update rules to each.\n",
"\n",
"We now know enough to see that\n",
"this is not great for our GPU utilization:\n",
"there are many areas of gray\n",
"in between the colored bars\n",
"in the GPU stream in this area.\n",
"\n",
"In the time it takes CUDA to multiply\n",
"thousands of numbers,\n",
"Python has not yet finished cleaning up\n",
"after its request for that multiplication.\n",
"\n",
"As of writing in August 2022,\n",
"more efficient optimizers are not a stable part of PyTorch (v1.12), but\n",
"[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n",
"and stable implementations outside of PyTorch.\n",
"The standard implementations are\n",
"[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n",
"not to be confused with the\n",
"[Apex Optimizers Project](https://www.apexoptimizers.com/),\n",
"which is a collection of fitness-themed cheetah NFTs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WX0jxeafvLsI"
},
"source": [
"## Take-aways for PyTorch performance bottleneck troubleshooting"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CugD-bK2vLsI"
},
"source": [
"Our goal here was to learn some basic principles and tools for troubleshooting\n",
"the most common bottlenecks and picking the lowest-hanging fruit in PyTorch code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SwHwJkVMHYGA"
},
"source": [
"\n",
"Here's an overview in terms of a \"host\",\n",
"generally the CPU,\n",
"and a \"device\", here the GPU.\n",
"\n",
"- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n",
"- During execution, host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n",
" - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n",
"- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n",
"stores information about the contents of tensors,\n",
"not just their metadata."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gntx28p9cBP5"
},
"source": [
"Towards that goal, we viewed the trace to get an understanding of\n",
"what's going on inside a PyTorch training step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AKvZGPnkeXvq"
},
"source": [
"Here's what that means in terms of troubleshooting bottlenecks.\n",
"\n",
"We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n",
"before the previous kernel finishes.\n",
"\n",
"Ideally, the CPU is actually getting far _ahead_ of execution\n",
"on the GPU.\n",
"If the CPU makes it all the way through the backwards pass before the GPU is done,\n",
"that's great!\n",
"The GPU(s) are the expensive part,\n",
"and it's easy to use multiprocessing so that\n",
"the CPU has other things to do.\n",
"\n",
"This helps explain at least one common piece of advice:\n",
"the larger our batches are,\n",
"the more work the GPU has to do for the same work done by the CPU,\n",
"and so the better our utilization will be."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XMztpa-TccH4"
},
"source": [
"We operationalize our desire to never be waiting on the CPU with a simple metric:\n",
"**100% GPU utilization**, meaning a kernel is running at all times.\n",
"\n",
"This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n",
"\n",
"You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will cost more than 2x as much."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7kYBygfScR6z"
},
"source": [
"Here are some of the most common issues that lead to low GPU Utilization, and how to resolve them:\n",
"1. **The CPU is too weak**.\n",
"Because so much of the discussion around DNN performance is about GPUs,\n",
"it's easy when specing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n",
"_Resolution_:\n",
"Use nice CPUs, like\n",
"[threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n",
"2. **Too much Python during the `training_step`**.\n",
"Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on a `__init__`\n",
"that takes longer than running an entire layer.\n",
"_Resolution_:\n",
"Look for low utilization areas of the trace\n",
"and check what's happening on the CPU at that time\n",
"and carefully review the Python code being executed.\n",
"3. **Unnecessary Host/Device synchronization**.\n",
"If one of your operations depends on the values in a tensor,\n",
"like `if xs.mean() >= 0`,\n",
"you'll induce a synchronization between\n",
"the host and the device and possibly lead\n",
"to an expensive and slow copy of data.\n",
"_Resolution_:\n",
"Replace these operations as much as possible\n",
"with purely array-based calculations.\n",
"4. **Bottlenecking on the DataLoader**.\n",
"In addition to coordinating the work on the GPU,\n",
"CPUs often perform heavy data operations,\n",
"including communication over the network\n",
"and writing to/reading from disk.\n",
"These are generally done in parallel to the forwards\n",
"and backwards passes,\n",
"but if they don't finish before that happens,\n",
"they will become the bottleneck.\n",
"_Resolution_:\n",
"Get better hardware for compute,\n",
"memory, and network.\n",
"For software solutions, the answer \n",
"is a bit more complex and application-dependent.\n",
"For generic tips, see\n",
"[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n",
"in the PyTorch forums.\n",
"For techniques in computer vision, see\n",
"[the FFCV library](https://github.com/libffcv/ffcv)\n",
"and for techniques in NLP, see e.g.\n",
"[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n",
"and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i2WYS8bQvLsJ"
},
"source": [
"### Further steps in making DNNs go brrrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T0wW2_lRKfY1"
},
"source": [
"It's important to note that utilization\n",
"is just an easily measured metric\n",
"that can reveal common bottlenecks.\n",
"Having high utilization does not automatically mean\n",
"that your performance is fully optimized.\n",
"\n",
"For example,\n",
"synchronization events between GPUs\n",
"are counted as kernels,\n",
"so a deadlock during distributed training\n",
"can show up as 100% utilization,\n",
"despite literally no useful work occurring.\n",
"\n",
"Just switching to \n",
"double precision floats, `--precision=64`,\n",
"will generally lead to much higher utilization.\n",
"The GPU operations take longer\n",
"for roughly the same amount of CPU effort,\n",
"but the added precision brings no benefit.\n",
"\n",
"In particular, it doesn't make for models\n",
"that perform better on our correctness metrics,\n",
"like loss and accuracy.\n",
"\n",
"Another useful yardstick to add\n",
"to utilization is examples per second,\n",
"which incorporates how quickly the model is processing data examples\n",
"and calculating gradients.\n",
"\n",
"But really,\n",
"the gold star is _decrease in loss per second_.\n",
"This metric connects model design choices\n",
"and hyperparameters with purely engineering concerns,\n",
"so it disrespects abstraction barriers\n",
"and doesn't generally lead to actionable recommendations,\n",
"but it is, in the end, the real goal:\n",
"make the loss go down faster so we get better models sooner."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EFzPsplfdo_o"
},
"source": [
"For PyTorch internals abstractly,\n",
"see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"For more on performance considerations in PyTorch,\n",
"see [Horace He's blog post](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yq6-S6TC38AY"
},
"source": [
"### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n",
"\n",
"One of the most important features for making\n",
"PyTorch run quickly is the\n",
"`MultiprocessingDataLoader`,\n",
"which executes batching of data in a separate process\n",
"from the forwards and backwards passes.\n",
"\n",
"By default in PyTorch,\n",
"this feature is actually turned off,\n",
"via the `DataLoader` argument `num_workers`\n",
"having a default value of `0`,\n",
"but we set the `DEFAULT_NUM_WORKERS`\n",
"to a value based on the number of CPUs\n",
"available on the system running the code.\n",
"\n",
"Re-run the profiling cell,\n",
"but set `num_workers` to `0`\n",
"to turn off multiprocessing.\n",
"\n",
"Compare and contrast the two traces,\n",
"both for total runtime\n",
"(see the time axis at the top of the trace)\n",
"and for utilization.\n",
"\n",
"If you're unable to run the profiles,\n",
"see the results\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n",
"which juxtaposes two traces,\n",
"with in-process dataloading on the left and\n",
"multiprocessing dataloading on the right."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T2i_a5eVeIoA"
},
"source": [
"The file below incorrectly implements and then incorrectly tests\n",
"a simple PyTorch utility for adding five to every entry of a tensor\n",
"and then calculating the sum.\n",
"\n",
"Even worse, it does it with horrible style!\n",
"\n",
"The cells below apply our linting checks\n",
"(after automatically fixing the formatting)\n",
"and run the test.\n",
"\n",
"Fix all of the lints,\n",
"implement the function correctly,\n",
"and then implement some basic tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wSon2fB5VVM_"
},
"source": [
"- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aYiRvU4HA84t"
},
"outputs": [],
"source": [
"%%writefile training/fixme.py\n",
"import torch\n",
"from training import run_experiment\n",
"from numpy import *\n",
"import random\n",
"from pathlib import Path\n",
"\n",
"\n",
"\n",
"\n",
"def add_five_and_sum(tensor):\n",
" # this function is not implemented right,\n",
" # but it's supposed to add five to all tensor entries and sum them up\n",
" return 1\n",
"\n",
"def test_add_five_and_sum():\n",
" # and this test isn't right either! plus this isn't exactly a docstring\n",
" all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n",
" all_fives = 5 * all_ones\n",
" assert False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EXJpmvuzT1w0"
},
"outputs": [],
"source": [
"!pre-commit run black --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SRO-oJfdUrcQ"
},
"outputs": [],
"source": [
"!cat training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jM8NHxVbSEQD"
},
"outputs": [],
"source": [
"!pre-commit run --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kj0VMBSndtkc"
},
"outputs": [],
"source": [
"!pytest training/fixme.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab05_troubleshooting.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab05/tasks/lint.sh
================================================
#!/bin/bash
set -uo pipefail
set +e

FAILURE=false

# Run every pre-commit hook, remembering (but not stopping on) failures:
#   black      -- apply automatic formatting
#   flake8     -- check for python code style violations, see .flake8 for details
#   shellcheck -- check for shell scripting style violations and common bugs
#   mypy       -- check python types
for hook in black flake8 shellcheck mypy; do
  echo "$hook"
  pre-commit run "$hook" || FAILURE=true
done

if [ "$FAILURE" = true ]; then
  echo "Linting failed"
  exit 1
fi

echo "Linting passed"
exit 0
================================================
FILE: lab05/text_recognizer/__init__.py
================================================
"""Modules for creating and running a text recognizer."""
================================================
FILE: lab05/text_recognizer/callbacks/__init__.py
================================================
from .model import ModelSizeLogger
from .optim import LearningRateMonitor
from . import imtotext
from .imtotext import ImageToTextTableLogger as ImageToTextLogger
================================================
FILE: lab05/text_recognizer/callbacks/imtotext.py
================================================
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
from .util import check_and_warn
class ImageToTextTableLogger(pl.Callback):
    """Logs the inputs and outputs of an image-to-text model to Weights & Biases.

    Each logged batch becomes a table with one row per example:
    the input image, the ground-truth string, and the predicted string.

    Args:
        max_images_to_log: How many examples per batch to log; clamped to the range [1, 32].
        on_train: Whether to log training batches. Validation batches are always logged.
    """

    def __init__(self, max_images_to_log=32, on_train=True):
        super().__init__()
        self.max_images_to_log = min(max(max_images_to_log, 1), 32)
        self.on_train = on_train
        self._required_keys = ["gt_strs", "pred_strs"]

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, output, batch, batch_idx):
        if self.on_train and self.has_metrics(output):
            if check_and_warn(trainer.logger, "log_table", "image-to-text table"):
                return
            self._log_image_text_table(trainer, output, batch, "train/predictions")

    @rank_zero_only
    def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 to match Lightning's hook signature, which may omit it
        if self.has_metrics(output):
            if check_and_warn(trainer.logger, "log_table", "image-to-text table"):
                return
            self._log_image_text_table(trainer, output, batch, "validation/predictions")

    def _log_image_text_table(self, trainer, output, batch, key):
        """Log up to `max_images_to_log` (image, ground truth, prediction) rows under `key`."""
        xs, _ = batch
        gt_strs = output["gt_strs"]
        pred_strs = output["pred_strs"]

        mx = self.max_images_to_log
        xs, gt_strs, pred_strs = xs[:mx], gt_strs[:mx], pred_strs[:mx]

        xs = [wandb.Image(x) for x in xs]

        rows = zip(*[xs, gt_strs, pred_strs])
        columns = ["input_image", "ground_truth_string", "predicted_string"]
        trainer.logger.log_table(key=key, columns=columns, data=list(rows))

    def has_metrics(self, output):
        """Return True if `output` is dict-like and contains every required key."""
        try:
            keys = output.keys()
        except AttributeError:  # e.g. a step that returned None or a bare loss tensor
            return False
        return all(key in keys for key in self._required_keys)
class ImageToTextCaptionLogger(pl.Callback):
    """Logs the inputs and outputs of an image-to-text model to Weights & Biases.

    Each logged batch becomes a set of images captioned with the model's predicted strings.

    Args:
        max_images_to_log: How many examples per batch to log; clamped to the range [1, 32].
        on_train: Whether to log training batches. Validation and test batches are always logged.
    """

    def __init__(self, max_images_to_log=32, on_train=True):
        super().__init__()
        self.max_images_to_log = min(max(max_images_to_log, 1), 32)
        self.on_train = on_train
        self._required_keys = ["gt_strs", "pred_strs"]

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, output, batch, batch_idx):
        # only log training batches when requested via `on_train`
        if self.on_train and self.has_metrics(output):
            if check_and_warn(trainer.logger, "log_image", "image-to-text"):
                return
            self._log_image_text_caption(trainer, output, batch, "train/predictions")

    @rank_zero_only
    def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 to match Lightning's hook signature, which may omit it
        if self.has_metrics(output):
            if check_and_warn(trainer.logger, "log_image", "image-to-text"):
                return
            self._log_image_text_caption(trainer, output, batch, "validation/predictions")

    @rank_zero_only
    def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        if self.has_metrics(output):
            if check_and_warn(trainer.logger, "log_image", "image-to-text"):
                return
            self._log_image_text_caption(trainer, output, batch, "test/predictions")

    def _log_image_text_caption(self, trainer, output, batch, key):
        """Log up to `max_images_to_log` input images under `key`, captioned with predictions."""
        xs, _ = batch
        gt_strs = output["gt_strs"]
        pred_strs = output["pred_strs"]

        mx = self.max_images_to_log
        xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], pred_strs[:mx]

        trainer.logger.log_image(key, xs, caption=pred_strs)

    def has_metrics(self, output):
        """Return True if `output` is dict-like and contains every required key."""
        try:
            keys = output.keys()
        except AttributeError:  # e.g. a step that returned None or a bare loss tensor
            return False
        return all(key in keys for key in self._required_keys)
================================================
FILE: lab05/text_recognizer/callbacks/model.py
================================================
import os
from pathlib import Path
import tempfile
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import torch
from .util import check_and_warn, logging
try:
import torchviz
has_torchviz = True
except ImportError:
has_torchviz = False
class ModelSizeLogger(pl.Callback):
    """Logs information about model size (in parameters and on disk).

    Args:
        print_size: Whether to also print the on-disk size of the state dict.
    """

    def __init__(self, print_size=True):
        super().__init__()
        self.print_size = print_size

    @rank_zero_only
    def on_fit_start(self, trainer, module):
        self._run(trainer, module)

    def _run(self, trainer, module):
        """Measure the module, optionally print its disk size, and log both metrics under `size/`."""
        metrics = {}
        metrics["mb_disk"] = self.get_model_disksize(module)
        metrics["nparams"] = count_params(module)

        if self.print_size:
            print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB")

        # trainer.logger may be missing (e.g. None when logging is disabled); warn and skip then
        if check_and_warn(trainer.logger, "log_metrics", "model size"):
            return

        metrics = {f"size/{key}": value for key, value in metrics.items()}
        trainer.logger.log_metrics(metrics, step=-1)

    @staticmethod
    def get_model_disksize(module):
        """Determine the model's size on disk, in MB (1e6 bytes), by saving its state dict to a temp file."""
        with tempfile.NamedTemporaryFile() as f:
            torch.save(module.state_dict(), f)
            f.flush()  # make sure no bytes are still sitting in Python's write buffer before measuring
            size_mb = os.path.getsize(f.name) / 1e6
        return size_mb
class GraphLogger(pl.Callback):
    """Logs a compute graph as an image.

    Requires the optional `torchviz` package.

    Args:
        output_key: Key of the tensor to graph, looked up in the training outputs'
            `outputs[0][0]["extra"]` dict.
    """

    def __init__(self, output_key="logits"):
        super().__init__()
        self.graph_logged = False
        self.output_key = output_key
        if not has_torchviz:
            raise ImportError("GraphLogger requires torchviz.")

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 because newer Lightning versions no longer pass it to this hook
        if not self.graph_logged:
            try:
                outputs = outputs[0][0]["extra"]
                self.log_graph(trainer, module, outputs[self.output_key])
            except KeyError:
                logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}")
            # only attempt this once, whether or not it succeeded
            self.graph_logged = True

    @staticmethod
    def log_graph(trainer, module, outputs):
        """Render the autograd graph of `outputs` to a PNG in the logger's run dir and log it."""
        if check_and_warn(trainer.logger, "log_image", "graph"):
            return
        params_dict = dict(list(module.named_parameters()))
        graph = torchviz.make_dot(outputs, params=params_dict)
        graph.format = "png"
        fname = Path(trainer.logger.experiment.dir) / "graph"
        graph.render(fname)
        fname = str(fname.with_suffix("." + graph.format))
        trainer.logger.log_image(key="graph", images=[fname])
def count_params(module):
    """Return the total number of scalar entries across all parameters of a Torch Module."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
================================================
FILE: lab05/text_recognizer/callbacks/optim.py
================================================
import pytorch_lightning as pl
KEY = "optimizer"


class LearningRateMonitor(pl.callbacks.LearningRateMonitor):
    """Lightning's LearningRateMonitor, with every logged name prefixed by KEY and a slash.

    Logs the learning rate during training. See the docs for
    pl.callbacks.LearningRateMonitor for details.
    """

    def _add_prefix(self, *args, **kwargs) -> str:
        # group all learning-rate metrics under the "optimizer/" namespace
        base_name = super()._add_prefix(*args, **kwargs)
        return "/".join([KEY, base_name])
================================================
FILE: lab05/text_recognizer/callbacks/util.py
================================================
import logging
# Configure the root logger so WARNING-level messages (like those emitted below) are shown.
# Note (stdlib behavior): basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.WARNING)
def check_and_warn(logger, attribute, feature) -> bool:
    """Check whether `logger` can log `feature`, warning if it cannot.

    Callers use the result as a "skip" flag: ``if check_and_warn(...): return``.

    Args:
        logger: The logger object to inspect (may be None).
        attribute: Name of the logger method the feature needs, e.g. "log_table".
        feature: Human-readable name of the feature, used in the warning.

    Returns:
        True if `logger` lacks `attribute` (a warning was emitted), False otherwise.
    """
    if hasattr(logger, attribute):
        return False
    warn_no_attribute(feature, attribute)
    return True
def warn_no_attribute(blocked_feature, missing_attribute):
    """Emit a warning that `blocked_feature` can't be logged because the logger lacks `missing_attribute`."""
    message = f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}."
    logging.warning(message)
================================================
FILE: lab05/text_recognizer/data/__init__.py
================================================
"""Module containing submodules for each dataset.
Each dataset is defined as a class in that submodule.
The datasets should have a .config method that returns
any configuration information needed by the model.
Most datasets define their constants in a submodule
of the metadata module that is parallel to this one in the
hierarchy.
"""
from .util import BaseDataset
from .base_data_module import BaseDataModule
from .mnist import MNIST
from .emnist import EMNIST
from .emnist_lines import EMNISTLines
from .iam_paragraphs import IAMParagraphs
from .iam_lines import IAMLines
from .fake_images import FakeImageData
================================================
FILE: lab05/text_recognizer/data/base_data_module.py
================================================
"""Base DataModule class."""
import argparse
import os
from pathlib import Path
from typing import Collection, Dict, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
from text_recognizer import util
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.shared as metadata
def load_and_print_info(data_module_class) -> None:
    """Build a data module from command-line arguments, prepare and set it up, then print it.

    Arguments are parsed from the command line (``sys.argv``) using the options the class
    registers via ``add_to_argparse``.

    Args:
        data_module_class: A data module class exposing ``add_to_argparse(parser)``, a constructor
            that accepts the parsed ``argparse.Namespace``, ``prepare_data()``, and ``setup()``.
    """
    parser = argparse.ArgumentParser()
    data_module_class.add_to_argparse(parser)
    args = parser.parse_args()

    dataset = data_module_class(args)
    dataset.prepare_data()
    dataset.setup()
    print(dataset)
def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
    """Download the raw dataset file described by `metadata` into `dl_dirname`, verifying its SHA-256.

    If the file already exists it is returned as-is, without re-checking its hash.

    Args:
        metadata: Mapping with "filename", "url", and "sha256" entries.
        dl_dirname: Directory to download into; created (with parents) if missing.

    Returns:
        Path to the downloaded (or previously downloaded) file.

    Raises:
        ValueError: If the downloaded file's SHA-256 does not match `metadata["sha256"]`.
            The mismatching file is deleted first, so a later call downloads it again
            instead of reusing the bad copy.
    """
    dl_dirname.mkdir(parents=True, exist_ok=True)
    filename = dl_dirname / metadata["filename"]
    if filename.exists():
        return filename
    print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
    util.download_url(metadata["url"], filename)
    print("Computing SHA-256...")
    sha256 = util.compute_sha256(filename)
    if sha256 != metadata["sha256"]:
        # remove the bad download; otherwise the `exists()` check above would return it next time
        if filename.exists():
            filename.unlink()
        raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
    return filename
BATCH_SIZE = 128

try:
    # number of CPUs this process is allowed to run on (respects its CPU affinity mask)
    NUM_AVAIL_CPUS = len(os.sched_getaffinity(0))
except AttributeError:
    # os.sched_getaffinity only exists on some platforms (not on macOS or Windows)
    NUM_AVAIL_CPUS = os.cpu_count() or 1
NUM_AVAIL_GPUS = torch.cuda.device_count()

# sensible multiprocessing defaults: at most one worker per CPU,
# but in distributed data parallel mode, we launch a training on each GPU,
# so must divide out to keep total at one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else NUM_AVAIL_CPUS
class BaseDataModule(pl.LightningDataModule):
    """Base for all of our LightningDataModules.

    Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html
    """

    def __init__(self, args: Optional[argparse.Namespace] = None) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS)

        # `gpus` is None when unset, else an int count or a str spec of GPU ids;
        # `--gpus 0` requests CPU-only training, so memory pinning is not wanted then
        gpus = self.args.get("gpus", None)
        self.on_gpu = isinstance(gpus, (str, int)) and str(gpus) != "0"

        # Make sure to set the variables below in subclasses
        self.input_dims: Tuple[int, ...]
        self.output_dims: Tuple[int, ...]
        self.mapping: Collection
        self.data_train: Union[BaseDataset, ConcatDataset]
        self.data_val: Union[BaseDataset, ConcatDataset]
        self.data_test: Union[BaseDataset, ConcatDataset]

    @classmethod
    def data_dirname(cls):
        return metadata.DATA_DIRNAME

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--batch_size",
            type=int,
            default=BATCH_SIZE,
            help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=DEFAULT_NUM_WORKERS,
            help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.",
        )
        return parser

    def config(self):
        """Return important settings of the dataset, which will be passed to instantiate models."""
        return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping}

    def prepare_data(self, *args, **kwargs) -> None:
        """Take the first steps to prepare data for use.

        Use this method to do things that might write to disk or that need to be done only from a single GPU
        in distributed settings (so don't set state `self.x = y`).
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """Perform final setup to prepare data for consumption by DataLoader.

        Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap `dataset` in a DataLoader using this module's batch size, worker count, and pinning."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.data_test, shuffle=False)
================================================
FILE: lab05/text_recognizer/data/emnist.py
================================================
"""EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present."""
import json
import os
from pathlib import Path
import shutil
from typing import Sequence
import zipfile
import h5py
import numpy as np
import toml
from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset, split_dataset
import text_recognizer.metadata.emnist as metadata
from text_recognizer.stems.image import ImageStem
from text_recognizer.util import temporary_working_directory
# Short module-level aliases for values defined in text_recognizer.metadata.emnist.
(
    NUM_SPECIAL_TOKENS,
    RAW_DATA_DIRNAME,
    METADATA_FILENAME,
    DL_DATA_DIRNAME,
    PROCESSED_DATA_DIRNAME,
    PROCESSED_DATA_FILENAME,
    ESSENTIALS_FILENAME,
) = (
    metadata.NUM_SPECIAL_TOKENS,
    metadata.RAW_DATA_DIRNAME,
    metadata.METADATA_FILENAME,
    metadata.DL_DATA_DIRNAME,
    metadata.PROCESSED_DATA_DIRNAME,
    metadata.PROCESSED_DATA_FILENAME,
    metadata.ESSENTIALS_FILENAME,
)
SAMPLE_TO_BALANCE = True  # If true, take at most the mean number of instances per class.
TRAIN_FRAC = 0.8  # fraction handed to split_dataset when carving train/val out of the training data
class EMNIST(BaseDataModule):
    """EMNIST dataset of handwritten characters and digits.

    "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
    and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
    From https://www.nist.gov/itl/iad/image-group/emnist-dataset

    The data split we will use is
    EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
    """

    def __init__(self, args=None):
        super().__init__(args)
        self.mapping = metadata.MAPPING
        # Reverse lookup: each entry of self.mapping -> its index.
        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
        self.transform = ImageStem()
        self.input_dims = metadata.DIMS
        self.output_dims = metadata.OUTPUT_DIMS

    def prepare_data(self, *args, **kwargs) -> None:
        """Download and process the raw data unless the processed HDF5 file already exists."""
        if not os.path.exists(PROCESSED_DATA_FILENAME):
            _download_and_process_emnist()

    def setup(self, stage: str = None) -> None:
        """Load the processed HDF5 arrays and wrap them in datasets.

        Args:
            stage: "fit" builds the train/val split, "test" builds the test set,
                and None builds both.
        """
        if stage == "fit" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_trainval = f["x_train"][:]
                self.y_trainval = f["y_train"][:].squeeze().astype(int)
            data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
            # A fixed seed is passed to split_dataset (presumably for a reproducible split).
            self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)
        if stage == "test" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_test = f["x_test"][:]
                self.y_test = f["y_test"][:].squeeze().astype(int)
            self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)

    def __repr__(self):
        """Describe the dataset; include split sizes and batch stats only for splits that are set up."""
        basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n"
        # BaseDataModule only annotates data_train/data_val/data_test without assigning them, so the
        # attributes do not exist until setup() has run for the matching stage.
        data_train = getattr(self, "data_train", None)
        data_val = getattr(self, "data_val", None)
        data_test = getattr(self, "data_test", None)
        splits = (data_train, data_val, data_test)
        if all(split is None for split in splits):
            return basic
        sizes = ", ".join("None" if split is None else str(len(split)) for split in splits)
        data = f"Train/val/test sizes: {sizes}\n"
        # Batch statistics come from the training loader, which exists only after setup("fit").
        if data_train is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
            )
        return basic + data
def _download_and_process_emnist():
    """Fetch the raw EMNIST archive described by the metadata TOML, then convert it."""
    # Named dataset_info so it does not shadow the module-level `metadata` import.
    dataset_info = toml.load(METADATA_FILENAME)
    _download_raw_dataset(dataset_info, DL_DATA_DIRNAME)
    archive_name = dataset_info["filename"]
    _process_raw_dataset(archive_name, DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path):
    """Convert the downloaded EMNIST zip into a compressed HDF5 file plus an essentials JSON.

    Args:
        filename: name of the downloaded zip archive, resolved relative to ``dirname``.
        dirname: directory holding the archive; it is the working directory while processing,
            so the .mat file is extracted under ``dirname/matlab`` and removed afterwards.
    """
    print("Unzipping EMNIST...")
    with temporary_working_directory(dirname):
        with zipfile.ZipFile(filename, "r") as zf:
            zf.extract("matlab/emnist-byclass.mat")
        try:
            # NOTE: If importing at the top of module, would need to list scipy as prod dependency.
            from scipy.io import loadmat

            print("Loading training data from .mat file")
            data = loadmat("matlab/emnist-byclass.mat")
            # Images come out of the .mat file as rows: reshape to (N, 28, 28), then swap axes 1 and 2.
            x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
            # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices
            y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
            x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
            y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
            if SAMPLE_TO_BALANCE:
                print("Balancing classes to reduce amount of data")
                x_train, y_train = _sample_to_balance(x_train, y_train)
                x_test, y_test = _sample_to_balance(x_test, y_test)
            print("Saving to HDF5 in a compressed format...")
            PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
            with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
                # Labels are stored as uint8 as well (assumes every label + NUM_SPECIAL_TOKENS < 256).
                f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
                f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
                f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
                f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
            print("Saving essential dataset parameters to text_recognizer/data...")
            # The .mat mapping rows are (class index, character code point) pairs.
            mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
            characters = _augment_emnist_characters(list(mapping.values()))
            essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
            with open(ESSENTIALS_FILENAME, "w") as f:
                json.dump(essentials, f)
        except BaseException:
            # Don't leave the large extracted .mat file behind when processing fails; keep the original error.
            shutil.rmtree("matlab", ignore_errors=True)
            raise
        print("Cleaning up...")
        shutil.rmtree("matlab")
def _sample_to_balance(x, y):
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(42)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_inds = []
for label in np.unique(y.flatten()):
inds = np.where(y == label)[0]
sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
all_sampled_inds.append(sampled_inds)
ind = np.concatenate(all_sampled_inds)
x_sampled = x[ind]
y_sampled = y[ind]
return x_sampled, y_sampled
def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
"""Augment the mapping with extra symbols."""
# Extra characters from the IAM dataset
iam_characters = [
" ",
"!",
'"',
"#",
"&",
"'",
"(",
")",
"*",
"+",
",",
"-",
".",
"/",
":",
";",
"?",
]
# Also add special tokens:
# - CTC blank token at index 0
# - Start token at index 1
# - End token at index 2
# - Padding token at index 3
# NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this!
# NOTE(review): the return statement below appears corrupted. Its string literals are empty or split
# across a line break, so it is not valid Python as written. It references `tokens`, which is not
# defined in this function, and it never uses `characters` or `iam_characters`. Going by the comment
# above, it presumably should return the four special-token strings followed by the EMNIST
# `characters` and `iam_characters` -- TODO restore the exact token strings from version control.
return ["", "
", "", " and ", *tokens, " and ", *tokens, ""]
self.end_index = self.inverse_mapping["",
""]
self.end_token = inverse_mapping[""]
self.end_token = inverse_mapping[""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing\n",
"and matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanding of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect something like one in ten predictions to be correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how well our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run our `backward`s pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss.backward()` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset were, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need's not out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()  # the nn.Module.__init__ method does important setup, so this is mandatory\n",
"        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
"        self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameters`\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built in to `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameters`,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch assigns the `nn.Parameters`\n",
"of any child `nn.Module`s to the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (`fc` and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine move around electrons."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, the problem of CPUs,\n",
"which are general purpose computing devices,\n",
"being too slow is solved by using more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in for graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.utils.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"the redheaded stepchild of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
"    url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
"    filename = \"mnist.pkl.gz\"\n",
"\n",
"    def __init__(self, dir, bs=32):\n",
"        self.dir = dir\n",
"        self.bs = bs\n",
"        self.path = self.dir / self.filename\n",
"\n",
"    def prepare_data(self):\n",
"        if not self.path.exists():\n",
"            response = requests.get(self.url + self.filename)\n",
"            response.raise_for_status()  # don't cache an HTTP error page as if it were the dataset\n",
"            self.path.write_bytes(response.content)  # opens, writes, and closes the file\n",
"\n",
"    def setup(self):\n",
"        # NOTE: pickle can execute code on load -- only use with this trusted source\n",
"        with gzip.open(self.path, \"rb\") as f:\n",
"            ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
"        x_train, y_train, x_valid, y_valid = map(\n",
"            torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
"        )\n",
"\n",
"        self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"        self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab06/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch to iteratively generate those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly, these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"This contrasts with the other common approach to framework design,\n",
"which is to provide abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of overriding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging # import some stdlib components to control what's displayed\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"try: # try using the LinearRegression LightningModule defined above\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n",
"\n",
" model = LinearRegression()\n",
"\n",
" # we'll explain how the Trainer works in a bit\n",
" trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
" trainer.fit(model=model) \n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
" print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n",
"\n",
"finally: # bring back info-level logging\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be overridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
" super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
" # set some constants, like the train/val split\n",
" self.size = size\n",
" self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
" self.train_indices = list(range(math.floor(self.size * train_frac)))\n",
" self.val_indices = list(range(len(self.train_indices), self.size)) # start right after the last train index, so the splits don't overlap\n",
"\n",
" # under the hood, we've still got a torch Dataset\n",
" self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once per machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n",
" if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n",
" self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
" self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self): # prepares state that needs to be set once per node\n",
" pass # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and overriding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"offsetting the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Also, implementing metric calculation during training adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab06/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other points.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution is said to have a _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we repeatedly apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper on convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (to right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represent (layer IDs at bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that custom here and used `__repr__`. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with intermittent \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work off of\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting the wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like DropOut\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first things you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"This helps prevent training runs from crashing\n",
"at the end of the first epoch,\n",
"which is otherwise the first time validation loops are triggered\n",
"and is sometimes hours into training,\n",
"by crashing them quickly at the start.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `<P>`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 03: Transformers and Paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The fundamental reasons why the Transformer is such\n",
"a powerful and popular architecture\n",
"- Core intuitions for the behavior of Transformer architectures\n",
"- How to use a convolutional encoder and a Transformer decoder to recognize\n",
"entire paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 3\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Transformers?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal in building a text recognizer is to take a two-dimensional image\n",
"and convert it into a one-dimensional sequence of characters\n",
"from some alphabet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional neural networks,\n",
"discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n",
"are great at encoding images,\n",
"taking them from their raw pixel values\n",
"to a more semantically meaningful numerical representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But how do we go from that to a sequence of letters?\n",
"And what's especially tricky:\n",
"the number of letters in an image is independent of its size.\n",
"A screenshot of this document has a much higher density of letters\n",
"than a close-up photograph of a piece of paper.\n",
"How do we get a _variable-length_ sequence of letters,\n",
"where the length need have nothing to do with the size of the input tensor?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n",
"they were\n",
"[originally introduced](https://arxiv.org/abs/1706.03762)\n",
"for transforming one sequence into another,\n",
"as in machine translation.\n",
"This makes them a natural fit for processing language.\n",
"\n",
"But they have also found success in other domains --\n",
"at the time of this writing, large transformers\n",
"dominate the\n",
"[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n",
"that has become a de facto standard for comparing models\n",
"and are finding\n",
"[application in reinforcement learning](https://arxiv.org/abs/2106.01345)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we will use a Transformer as a key component of our final architecture:\n",
"we will encode our input images with a CNN\n",
"and then read them out into a text sequence with a Transformer.\n",
"\n",
"Before trying out this new model,\n",
"let's first get an understanding of why the Transformer architecture\n",
"has become so popular by walking through its history\n",
"and then get some intuition for how it works\n",
"by looking at some\n",
"[recent work](https://transformer-circuits.pub/)\n",
"on explaining the behavior of both toy models and state-of-the-art language models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kmKqjbvd-Mj3"
},
"source": [
"## Why not convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SRqkUMdM-OxU"
},
"source": [
"In the ancient beforetimes (i.e. 2016),\n",
"the best models for natural language processing were all\n",
"_recurrent_ neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional networks were also occasionally used,\n",
"but they suffered from a serious issue:\n",
"their architectural biases don't fit text.\n",
"\n",
"First, _translation equivariance_ no longer holds.\n",
"The beginning of a piece of text is often quite different from the middle,\n",
"so the absolute position matters.\n",
"\n",
"Second, _locality_ is not as important in language.\n",
"The name of a character that hasn't appeared in thousands of pages\n",
"can become salient when someone asks, \"Whatever happened to\n",
"[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n",
"\n",
"Consider interpreting a piece of text like the Python code below:\n",
"```python\n",
"def do(arg1, arg2, arg3):\n",
" a = arg1 + arg2\n",
" b = arg3[:3]\n",
" c = a * b\n",
" return c\n",
"\n",
"print(do(1, 1, \"ayy lmao\"))\n",
"```\n",
"\n",
"After a `(` we expect a `)`,\n",
"but possibly very long afterwards,\n",
"[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n",
"and similarly we expect a `]` at some point after a `[`.\n",
"\n",
"For translation variance, consider\n",
"that we interpret `*` not by\n",
"comparing it to its neighbors\n",
"but by looking at `a` and `b`.\n",
"We mix knowledge learned through experience\n",
"with new facts learned while reading --\n",
"also known as _in-context learning_.\n",
"\n",
"In a longer text,\n",
"[e.g. the one you are reading now](./lab03_transformers.ipynb),\n",
"the translation variance of text is clearer.\n",
"Every lab notebook begins with the same header,\n",
"setting up the environment,\n",
"but that header never appears elsewhere in the notebook.\n",
"Later positions need to be processed in terms of the previous entries.\n",
"\n",
"Unlike an image, we cannot simply rotate or translate our \"camera\"\n",
"and get a new valid text.\n",
"[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n",
"that can be read without regard to position."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The field of formal language theory,\n",
"which has deep mutual influence with computer science,\n",
"gives one way of explaining the issues with convolutional networks:\n",
"they can only understand languages with _finite contexts_,\n",
"where all the information can be found within a finite window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The immediate solution, drawing from the connections to computer science, is\n",
"[recursion](https://www.google.com/search?q=recursion).\n",
"A network whose output on the final entry of the sequence is a recursive function\n",
"of all the previous entries can build up knowledge\n",
"as it reads the sequence and treat early entries quite differently than it does late ones."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa6cbTlImkEh"
},
"source": [
"In pseudo-code, such a _recurrent neural network_ module might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lKtBoPnglPrW"
},
"source": [
"```python\n",
"def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n",
" next_inputs = input_module(xs[-1])\n",
" next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n",
" return output_module(next_inputs, next_hiddens)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbJPSMnEm516"
},
"source": [
"If you've had formal computer science training,\n",
"then you may be familiar with the power of recursion,\n",
"e.g. the\n",
"[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n",
"that gave its name to the now much better-known\n",
"[startup incubator](https://www.ycombinator.com/).\n",
"\n",
"The particular form of recursion used by\n",
"recurrent neural networks implements a\n",
"[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n",
"\n",
"> If you know a lot of computer science,\n",
"you might be concerned by this connection.\n",
"What about other\n",
"[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n",
"Where are the neural network architectures for differentiable\n",
"[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n",
"Check out Graph Neural Networks,\n",
"[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "63mMTbEBpVuE"
},
"source": [
"Recurrent networks are able to achieve\n",
"[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n",
"\n",
"There are many popular recurrent architectures,\n",
"from the beefy and classic\n",
"[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n",
"to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n",
"([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n",
"all of which have roughly similar capabilities but\n",
"[some of which are easier to train](https://arxiv.org/abs/1611.09913)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwQHVTIslOku"
},
"source": [
"In the same sense that MLPs can model \"any\" feedforward function,\n",
"in principle even basic RNNs\n",
"[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n",
"\n",
"In particular they can model any\n",
"[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n",
"which is a formal way of saying that they can in principle\n",
"do anything a computer is capable of doing.\n",
"\n",
"The question is then..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3J8EoGN3pu7P"
},
"source": [
"## Why aren't we all using RNNs?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDwNWaevpt_3"
},
"source": [
"The guarantees that MLPs can model any function\n",
"or that RNNs can model Turing machines\n",
"provide decent intuition but are not directly practically useful.\n",
"Among other reasons, they don't guarantee learnability --\n",
"that starting from random parameters we can find the parameters\n",
"that implement a given function.\n",
"The\n",
"[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n",
"than would seem from basic theoretical and empirical analysis.\n",
"\n",
"One way of understanding capacity to model language is\n",
"[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n",
"In this model of formal languages,\n",
"Turing machines sit at the top\n",
"([practically speaking](https://arxiv.org/abs/math/0209332)).\n",
"\n",
"With better mathematical models,\n",
"RNNs and LSTMs can be shown to be\n",
"[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n",
"with RNNs looking more like\n",
"[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n",
"and LSTMs coming in\n",
"[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n",
"\n",
"More controversially:\n",
"the Chomsky hierarchy is great for understanding syntax and grammar,\n",
"which makes it great for building parsers\n",
"and working with formal languages,\n",
"but the goal in _natural_ language processing is to understand _natural_ language.\n",
"Most humans' natural language is far from strictly grammatical,\n",
"but that doesn't mean it is nonsense.\n",
"\n",
"And to really \"understand\" language means\n",
"to understand its semantic content, which is fuzzy.\n",
"The most important thing for handling the fuzzy semantic content\n",
"of language is not whether you can recall\n",
"[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n",
"but whether you can model probabilistic relationships between concepts\n",
"in addition to grammar and syntax."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These both leave theoretical room for improvement over current recurrent\n",
"language and sequence models.\n",
"\n",
"But the real cause of the rise of Transformers is that..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dsu1ebvAp-3Z"
},
"source": [
"## Transformers are designed to train fast at scale on contemporary hardware."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4abU5adsPGs"
},
"source": [
"The Transformer architecture has several important features,\n",
"discussed below,\n",
"but one of the most important reasons why it is successful\n",
"is that it can be more easily trained at scale.\n",
"\n",
"This scalability is the focus of the discussion in the paper\n",
"that introduced the architecture,\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"and\n",
"[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n",
"\n",
"The recursion in RNNs is inherently sequential:\n",
"the dependence on the outputs from earlier in the sequence\n",
"means computations within an example cannot be parallelized.\n",
"\n",
"So RNNs must batch across examples to scale,\n",
"but as sequence length grows this hits memory bandwidth limits.\n",
"Serving up large batches quickly with good randomness guarantees\n",
"is also hard to optimize,\n",
"especially in distributed settings.\n",
"\n",
"The Transformer architecture,\n",
"on the other hand,\n",
"can be readily parallelized within a single example sequence,\n",
"in addition to parallelization across batches.\n",
"This can lead to massive performance gains for a fixed scale,\n",
"which means larger, higher capacity models\n",
"can be trained on larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Mzk2haFC_G1"
},
"source": [
"How does the architecture achieve this parallelizability?\n",
"\n",
"Let's start with the architecture diagram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u59eu4snLQfp"
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n",
"\n",
"display.Image(url=base_url + \"/aiayn-figure-1.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ez-XEQ7M0UlR"
},
"source": [
"> To head off a bit of confusion\n",
" in case you've worked with Transformer architectures before:\n",
" the original \"Transformer\" is an encoder/decoder architecture.\n",
" Many LLMs, like GPT models, are decoder only,\n",
" because this has turned out to scale well,\n",
" and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n",
" it's all text anyways.\n",
" We, however, will be using them across modalities,\n",
" so we need an explicit encoder,\n",
" as above. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ok4ksBi4vp89"
},
"source": [
"First focusing on the encoder (left):\n",
"the encoding at a given position is a function of all of the inputs.\n",
"But it is not a function of the previous _encodings_:\n",
"we produce the encodings \"all at once\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN7C-_OqzHP"
},
"source": [
"The decoder (right) does use previous \"outputs\" as its inputs,\n",
"but those outputs are not the vectors of layer activations\n",
"(aka embeddings)\n",
"that are produced by the network.\n",
"They are instead the processed outputs,\n",
"after a `softmax` and an `argmax`.\n",
"\n",
"We could obtain these outputs by processing the embeddings,\n",
"much like in a recurrent architecture.\n",
"In fact, that is one way that Transformers are run.\n",
"It's what happens in the `.forward` method\n",
"of the model we'll be training for character recognition:\n",
"`ResnetTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L5_2WMmtDnJn"
},
"source": [
"Let's look at that forward method\n",
"and connect it to the diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FR5pk4kEyCGg"
},
"outputs": [],
"source": [
"from text_recognizer.models import ResnetTransformer\n",
"\n",
"\n",
"ResnetTransformer.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-J5UFDoPzPbq"
},
"source": [
"`.encode` happens first -- that's the left side of the diagram.\n",
"\n",
"The encoder can in principle be anything\n",
"that produces a sequence of fixed-length vectors,\n",
"but here it's\n",
"[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n",
"\n",
"Then we start iterating over the sequence\n",
"in the `for` loop.\n",
"\n",
"Focus on the first few lines of code.\n",
"We apply `.decode` (right side of the diagram)\n",
"to the outputs so far.\n",
"\n",
"Once we have a new `output`, we apply `.argmax`\n",
"to turn the logits into a concrete prediction of\n",
"a particular token.\n",
"\n",
"This is added as the last output token\n",
"and then the loop happens again."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTcy8-rV1dHr"
},
"source": [
"Run this way, our model looks very much like a recurrent architecture:\n",
"we call the model on its own outputs\n",
"to generate the next value.\n",
"These types of models are also referred to as\n",
"[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n",
"because we predict (as we do in _regression_)\n",
"the next value based on our own (_auto_) output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But Transformers are designed to be _trained_ more scalably than RNNs,\n",
"not necessarily to _run inference_ more scalably,\n",
"and it's actually not the case that our model's `.forward` is called during training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCxMSAWmEKBt"
},
"source": [
"Let's look at what happens during training\n",
"by checking the `training_step`\n",
"of the `LightningModule`\n",
"we use to train our Transformer models,\n",
"the `TransformerLitModel`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o7q8N7P2w4H"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import TransformerLitModel\n",
"\n",
"TransformerLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VgNNOjvzC4y"
},
"source": [
"Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tz-6NGPR4dUr"
},
"source": [
"Let's look at `.teacher_forward`,\n",
"and in particular its type signature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ILc2oWET4i2Z"
},
"outputs": [],
"source": [
"TransformerLitModel.teacher_forward??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lf32lpgrDb__"
},
"source": [
"This is known as \"teacher forcing\".\n",
"The \"teacher\" signal is \"forcing\"\n",
"the model to behave as though\n",
"it got the answer right.\n",
"\n",
"[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n",
"It's more effective here\n",
"because the right teaching signal\n",
"for our network is the target data,\n",
"which we have access to during training,\n",
"whereas in an RNN the best teaching signal\n",
"would be the target embedding vector,\n",
"which we do not know.\n",
"\n",
"During inference, when we don't have access to the ground truth,\n",
"we revert to the autoregressive `.forward` method."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This \"trick\" allows Transformer architectures to readily scale\n",
"up models to the parameter counts\n",
"[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAjqpJm9uUuU"
},
"source": [
"## Is there more to Transformers than just a training trick?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kWCYXeHv7Qc9"
},
"source": [
"[Very](https://arxiv.org/abs/2005.14165),\n",
"[very](https://arxiv.org/abs/1909.08053),\n",
"[very](https://arxiv.org/abs/2205.01068)\n",
"large Transformer models have powered the most recent wave of exciting results in ML, like\n",
"[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n",
"\n",
"They are also the first machine learning models to have come anywhere close to\n",
"deserving the term _artificial intelligence_ --\n",
"a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is surprising because the models and their training procedure are\n",
"(relatively speaking)\n",
"pretty _simple_,\n",
"even if it doesn't feel that way on first pass."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The basic Transformer architecture is just a bunch of\n",
"dense matrix multiplications and non-linearities --\n",
"it's perhaps simpler than a convolutional architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And advances since the introduction of Transformers in 2017\n",
"have not in the main been made by\n",
"creating more sophisticated model architectures\n",
"but by increasing the scale of the base architecture,\n",
"or if anything making it simpler, as in\n",
"[GPT-type models](https://arxiv.org/abs/2005.14165),\n",
"which drop the encoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V1HQS9ey8GMc"
},
"source": [
"These models are also trained on very simple tasks:\n",
"most LLMs are just trying to predict the next element in the sequence,\n",
"given the previous elements --\n",
"a task simple enough that Claude Shannon,\n",
"father of information theory, was\n",
"[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n",
"\n",
"These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n",
"e.g. by scraping the web."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They are also trained in a simple fashion:\n",
"first-order stochastic optimizers, like SGD or an\n",
"[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n",
"intended for the most basic of optimization problems,\n",
"that scale more readily than the second-order optimizers\n",
"that dominate other areas of optimization."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kz9HPDoy7OAl"
},
"source": [
"This is\n",
"[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n",
"of work in ML:\n",
"simple, even seemingly wasteful,\n",
"architectures that scale well and are robust\n",
"to implementation details\n",
"eventually outstrip more clever but\n",
"also more finicky approaches that are harder to scale.\n",
"This lesson has led some to declare that\n",
"[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n",
"in machine learning, and perhaps even in artificial intelligence."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdN9o2Y771YZ"
},
"source": [
"> That is not to say that because the algorithms are relatively simple,\n",
" training a model at this scale is _easy_ --\n",
" [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n",
" [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n",
" [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n",
" But choosing the simplest algorithm at every step makes solving the scaling problem feasible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baVGf6gKFOvs"
},
"source": [
"The importance of scale is the key lesson from the Transformer architecture,\n",
"far more than any theoretical considerations\n",
"or any of the implementation details.\n",
"\n",
"That said, these large Transformer models are capable of\n",
"impressive behaviors and understanding how they achieve them\n",
"is of intellectual interest.\n",
"Furthermore, like any architecture,\n",
"there are common failure modes,\n",
"of the model and of the modelers who use them,\n",
"that need to be taken into account."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"Below, we'll cover two key intuitions about Transformers:\n",
"Transformers are _residual_, like ResNets,\n",
"and they compose _low rank_ sequence transformations.\n",
"Together, this means they act somewhat like a computer,\n",
"reading from and writing to a \"tape\" or memory\n",
"with a sequence of simple instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"We'll also cover a surprising implementation detail:\n",
"despite being commonly used for sequence modeling,\n",
"by default the architecture is _position insensitive_."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uni0VTCr9lev"
},
"source": [
"### Intuition #1: Transformers are highly residual."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MoBt-JLJz-d"
},
"source": [
"> The discussion of these intuitions summarizes the discussion in\n",
"[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n",
"from\n",
"[Anthropic](https://www.anthropic.com/),\n",
"an AI safety and research company.\n",
"The figures below are from that blog post.\n",
"It is the spiritual successor to the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"covered in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab).\n",
"If you want to truly understand Transformers,\n",
"we highly recommend you check it out,\n",
"including the\n",
"[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UUbNVvM5Ferm"
},
"source": [
"It's easy to see that ResNets are residual --\n",
"it's in the name, after all.\n",
"\n",
"But Transformers are,\n",
"in some sense,\n",
"even more closely tied to residual computation\n",
"than are ResNets:\n",
"ResNets and related architectures include downsampling,\n",
"so there is not a direct path from inputs to outputs.\n",
"\n",
"In Transformers, the exact same shape is maintained\n",
"from the moment tokens are embedded,\n",
"through dozens or hundreds of intermediate layers,\n",
"and until they are \"unembedded\" into class logits.\n",
"The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n",
"\n",
"The residual stream is easy to see with a change of perspective.\n",
"Instead of the usual architecture diagram above,\n",
"which emphasizes the layers acting on the tensors,\n",
"consider this alternative view,\n",
"which emphasizes the tensors as they pass through the layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRMlVguKKW6y"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-residual-view.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9K3N7ilVkB3"
},
"source": [
"For definitions of variables and terms, see the\n",
"[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arvciE-kKd_L"
},
"source": [
"Note that this is a _decoder-only_ Transformer architecture --\n",
"so it should be compared with the right-hand side of the original architecture diagram above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvrRMd_RKp_G"
},
"source": [
"Notice that outputs of the attention blocks \n",
"and of the MLP layers are\n",
"added to their inputs, as in a ResNet.\n",
"These operations are represented as \"Add & Norm\" layers in the classical diagram;\n",
"normalization is ignored here for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8n_iT-FFAbK"
},
"source": [
"This total commitment to residual operations\n",
"means the size of the embeddings\n",
"(referred to as the \"model dimension\" or the \"embedding dimension\",\n",
"here and below `d_model`)\n",
"stays the same throughout the entire network.\n",
"\n",
"That means, for example,\n",
"that the output of each layer can be used as input to the \"unembedding\" layer\n",
"that produces logits.\n",
"We can read out the computations of intermediate layers\n",
"just by passing them through the unembedding layer\n",
"and examining the logit tensor.\n",
"See\n",
"[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n",
"for detailed experiments and interactive notebooks.\n",
"\n",
"In short, we observe a sort of \"progressive refinement\"\n",
"of the next-token prediction\n",
"as the embeddings proceed, depthwise, through the network."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ovh_3YgY9z2h"
},
"source": [
"### Intuition #2: Transformer heads learn low rank transformations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpNmozlnOdPC"
},
"source": [
"In the original paper and in\n",
"most presentations of Transformers,\n",
"the attention layer is written like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PA7me8gNP5LE"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pseudo-typed PyTorch (based loosely on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n",
"that looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oeict_6wGJgD"
},
"source": [
"```python\n",
"def classic_attention(\n",
" Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" K: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" return torch.softmax(Q @ K.T) @ V\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8pewU90DSuOR"
},
"source": [
"This is essentially\n",
"how it is written\n",
"in PyTorch,\n",
"apart from implementation details\n",
"(look for `bmm` for the matrix multiplications and a `softmax` call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrgTpKFvOhwc"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"F._scaled_dot_product_attention??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebDXZ0tlSe7g"
},
"source": [
"But the best way to write an operation so that a computer can execute it quickly\n",
"is not necessarily the best way to write it so that a human can understand it --\n",
"otherwise we'd all be coding in assembly.\n",
"\n",
"And this is a strange way to write it --\n",
"you'll notice that what we normally think of\n",
"as the \"inputs\" to the layer are not shown.\n",
"\n",
"We can instead write out the attention layer\n",
"as a function of the inputs $x$.\n",
"We write it for a single \"attention head\".\n",
"Each attention layer includes a number of heads\n",
"that read and write from the residual stream\n",
"simultaneously and independently.\n",
"We also add the output layer weights $W_O$\n",
"and we get:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LuFNR67tQpsf"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(\\underbrace{x^TW_Q^T}_Q \\underbrace{W_Kx}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVnBjjfOLwxP"
},
"source": [
"or, in pseudo-typed PyTorch:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmpOm-HfGaNz"
},
"source": [
"```python\n",
"def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n",
" key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n",
" key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n",
" # maps queries of residual stream to keys from residual stream, independent of position\n",
"\n",
" value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n",
" output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n",
" value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n",
" # transformation applied to each token, regardless of position\n",
"\n",
" attention_logits: torch.Tensor[\"d_sequence\", \"d_sequence\"] = x @ key_query_circuit @ x.T\n",
" attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits)\n",
" # maps positions to positions, often very sparse\n",
"\n",
" value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n",
"\n",
" return attention_map @ value_output # transformed tokens filtered by attention map\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dC0eqxZ6UAGT"
},
"source": [
"Consider the `key_query_circuit`\n",
"and `value_output_circuit`\n",
"matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n",
"\n",
"The key/query dimension, `d_head`,\n",
"is small relative to the model's dimension, `d_model`,\n",
"so $W_{QK}$ and $W_{OV}$ are very low rank,\n",
"[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n",
"that they factorize into two matrices,\n",
"one with a smaller number of rows\n",
"and another with a smaller number of columns.\n",
"That number is called the _rank_.\n",
"\n",
"When computing, these matrices are better represented via their components,\n",
"rather than computed directly,\n",
"which leads to the normal implementation of attention.\n",
"\n",
"In a large language model,\n",
"the ratio of residual stream dimension, `d_model`, to\n",
"the dimension of a single head, `d_head`, is huge, often 100:1.\n",
"That means each query, key, and value computed at a position\n",
"is a fairly simple, low-dimensional feature of the residual stream at that position.\n",
"\n",
"For visual intuition,\n",
"we compare what a matrix whose rank is 1/100th of full rank looks like,\n",
"relative to a full rank matrix of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LUbojJMiW2C"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"\n",
"low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n",
"full_rank = torch.randn(100, 100)\n",
"plt.figure(); plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n",
"plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqBst92-OVka"
},
"source": [
"The pattern in the first matrix is very simple,\n",
"relative to the pattern in the second matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkCGrs9EiVh4"
},
"source": [
"Another feature of low rank transformations is\n",
"that they have a large nullspace or kernel --\n",
"these are directions in which we can move the input without changing the output.\n",
"\n",
"That means that many changes to the residual stream won't affect the behavior of this head at all."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVz2dQgzhD4p"
},
"source": [
"### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVlzwR03m8mC"
},
"source": [
"The combination of residuality\n",
"(changes are added to the current input)\n",
"and low rank\n",
"(only a small subspace is changed by each head)\n",
"drastically changes the intuition about Transformers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qqjZI2jKe6HH"
},
"source": [
"Rather than being an \"embedding of a token in its context\",\n",
"the residual stream becomes something more like a memory or a scratchpad:\n",
"one layer reads a small bit of information from the stream\n",
"and writes a small bit of information back to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YIBkxlqepjc"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-layer-residual.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RtsKhkLfk00l"
},
"source": [
"The residual stream works like a memory because it is roomy enough\n",
"that these actions need not interfere:\n",
"the subspaces targeted by reads and writes are small relative to the ambient space,\n",
"so many reads and writes can share the stream without overlapping.\n",
"\n",
"Additionally, the dimension of each head is still in the 100s in large models,\n",
"and\n",
"[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n",
"in them, so the number of effective degrees of freedom is\n",
"actually larger than the dimension.\n",
"This phenomenon allows high-dimensional tensors to serve as\n",
"[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n",
"There are\n",
"[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n",
"\n",
"Together, this means an early layer can write information to the stream\n",
"that can be used by later layers -- by many of them at once, possibly much later.\n",
"Later layers can learn to edit this information,\n",
"e.g. deleting it,\n",
"if doing so reduces the loss,\n",
"but by default the information is preserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EragIygzJg86"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-stream-read-write.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKIaUZjwkpW7"
},
"source": [
"Lastly, the softmax in the attention has a sparsifying effect,\n",
"and so many attention heads are reading from \n",
"just one token and writing to just one other token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dN6VcJqIMKnB"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-token-to-token.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Repeatedly reading information from an external memory\n",
"and using it to decide which operation to perform\n",
"and where to write the results\n",
"is at the core of the\n",
"[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n",
"For a concrete example, the\n",
"[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n",
"includes a dissection of a form of \"pointer arithmetic\"\n",
"that appears in some models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kLFh7Mvnolr"
},
"source": [
"This point of view seems\n",
"very promising for explaining numerous\n",
"otherwise perhaps counterintuitive features of Transformer models.\n",
"\n",
"- This framework predicts that Transformers will readily copy-and-paste information,\n",
"which might explain phenomena like\n",
"[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n",
"\n",
"- It also readily explains\n",
"[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n",
"an important component of why Transformers perform well on medium-length texts\n",
"and in few-shot learning.\n",
"\n",
"- Transformers also perform better on reasoning tasks when the text\n",
"[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n",
"is added to their input prompt.\n",
"This is partly due to the fact that that prompt is associated,\n",
"in the dataset, with clearer reasoning,\n",
"and since the models are trained to predict which tokens tend to appear\n",
"after an input, they tend to produce better reasoning with that prompt --\n",
"an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n",
"that act to store intermediate information,\n",
"making for a richer residual stream\n",
"for reading and writing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RyLRzgG-93yB"
},
"source": [
"### Implementation detail: Transformers are position-insensitive by default."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oR6PnrlA_hJ2"
},
"source": [
"In the attention calculation\n",
"each token can query each other token,\n",
"with no regard for order.\n",
"Furthermore, the construction of queries, keys, and values\n",
"is based on the content of the embedding vector,\n",
"which does not automatically include its position.\n",
"\"dog bites man\" and \"man bites dog\" are identical, as in\n",
"[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n",
"\n",
"For most sequences,\n",
"this is unacceptable:\n",
"absolute and relative position matter\n",
"and we cannot use the future to predict the past.\n",
"\n",
"We need to add two pieces to get a Transformer architecture that's usable for next-token prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWHxGJz2-6ZK"
},
"source": [
"First, the simpler piece:\n",
"\"causal\" attention,\n",
"so-named because it ensures that values earlier in the sequence\n",
"are not influenced by later values, which would\n",
"[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0c42xi6URYB4"
},
"source": [
"The most common solution is straightforward:\n",
"we calculate attention between all tokens,\n",
"then throw out non-causal values by \"masking\" them\n",
"(this is before applying the softmax,\n",
"so masking means adding $-\\infty$).\n",
"\n",
"This feels wasteful --\n",
"why are we calculating values we don't need?\n",
"Trying to be smarter would be harder,\n",
"and might rely on operations that aren't as optimized as\n",
"matrix multiplication and addition.\n",
"Furthermore, it's \"only\" twice as many operations,\n",
"so it doesn't even show up in $O$-notation.\n",
"\n",
"A sample attention mask generated by our code base is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NXaWe6pT-9jV"
},
"outputs": [],
"source": [
"from text_recognizer.models import transformer_util\n",
"\n",
"\n",
"attention_mask = transformer_util.generate_square_subsequent_mask(100)\n",
"\n",
"ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n",
"plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n",
"print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This solves our causality problem,\n",
"but we still don't have positional information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZamUE4WIoGS2"
},
"source": [
"The standard technique\n",
"is to add alternating sines and cosines\n",
"of increasing frequency to the embeddings\n",
"(there are\n",
"[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n",
"most notably\n",
"[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n",
"Each position in the sequence is then uniquely identifiable\n",
"from the pattern of these values.\n",
"\n",
"> Furthermore, for the same reason that\n",
" [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n",
" translations, e.g. relative positions, are fairly easy to express as linear transformations\n",
" of sines and cosines."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IDG2uOsaELU0"
},
"source": [
"We superimpose this positional information on our embeddings.\n",
"Note that because the model is residual,\n",
"this position information will be by default preserved\n",
"as it passes through the network,\n",
"so it doesn't need to be repeatedly added."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what this positional encoding looks like in our codebase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Zk62Q-a-1Ax"
},
"outputs": [],
"source": [
"PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n",
"\n",
"pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n",
"\n",
"ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n",
"print(pe[:4, :8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep2ClIWvqDms"
},
"source": [
"When we add the positional information to our embeddings,\n",
"both the embedding information and the positional information\n",
"are approximately preserved,\n",
"as can be visually assessed below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJuFjoCzC0Y4"
},
"outputs": [],
"source": [
"fake_embeddings = torch.randn_like(pe) * 0.5\n",
"\n",
"ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n",
"\n",
"fake_embeddings_with_pe = fake_embeddings + pe\n",
"\n",
"plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHIzBxDkEmH8"
},
"source": [
"A [similar technique](https://arxiv.org/abs/2103.06450)\n",
"is used to also incorporate positional information into the image embeddings,\n",
"which are flattened before being fed to the decoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HC1N85wl8dvn"
},
"source": [
"### Learn more about Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lJwYxkjTk15t"
},
"source": [
"We're only able to give a flavor and an intuition for Transformers here.\n",
"\n",
"To improve your grasp on the nuts and bolts, check out the\n",
"[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n",
"which is surprisingly approachable,\n",
"as far as ML research papers go.\n",
"The\n",
"[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n",
"adds code and commentary to the original paper,\n",
"which makes it even more digestible.\n",
"For something even friendlier, check out the\n",
"[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n",
"by Jay Alammar, which has an accompanying\n",
"[video](https://youtu.be/-QH8fRhqFHM).\n",
"\n",
"Anthropic's work on\n",
"[Transformer Circuits](https://transformer-circuits.pub/),\n",
"summarized above, has some of the best material\n",
"for building theoretical understanding\n",
"and is still being updated with extensions and applications of the framework.\n",
"The\n",
"[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n",
"are a great aid for checking and building your understanding.\n",
"\n",
"But they are fairly math-heavy.\n",
"If you have more of a software engineering background, see\n",
"Transformer Circuits co-author Nelson Elhage's blog post\n",
"[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n",
"\n",
"For a gentler introduction to the intuition for Transformers,\n",
"check out Brandon Rohrer's\n",
"[Transformers From Scratch](https://e2eml.school/transformers.html)\n",
"tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qg7zntJES-aT"
},
"source": [
"An aside:\n",
"the matrix multiplications inside attention dominate\n",
"the big-$O$ runtime of Transformers.\n",
"So trying to make the attention mechanism more efficient, e.g. linear time,\n",
"has generated a lot of research\n",
"(review paper\n",
"[here](https://arxiv.org/abs/2009.06732)).\n",
"Despite drawing a lot of attention, so to speak,\n",
"at the time of writing in mid-2022, these methods\n",
"[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n",
"so it isn't likely to be worth the effort to spend time learning about them\n",
"unless you are a Transformer specialist."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCjXysEJ8g9_"
},
"source": [
"# Using Transformers to read paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsfKWnOvqjva"
},
"source": [
"Our simple convolutional model for text recognition from\n",
"[Lab 02b](https://fsdl.me/lab02b-colab)\n",
"could only handle cleanly-separated characters.\n",
"\n",
"It worked by sliding a LeNet-style CNN\n",
"over the image,\n",
"predicting a character for each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "njLdzBqy-I90"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist_lines = text_recognizer.data.EMNISTLines()\n",
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"\n",
"# for sliding, see the for loop over range(S)\n",
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0N6yDBQq8ns"
},
"source": [
"But unfortunately for us, handwritten text\n",
"doesn't come in neatly-separated characters\n",
"of equal size, so we trained our model on synthetic data\n",
"designed to work with that model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiqUVbj0sxLr"
},
"source": [
"Now that we have a better model,\n",
"we can work with better data:\n",
"paragraphs from the\n",
"[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oizsOAcKs-dD"
},
"source": [
"The cell below uses our `LightningDataModule`\n",
"to download and preprocess this data,\n",
"writing results to disk.\n",
"We can then spin up `DataLoader`s to give us batches.\n",
"\n",
"It can take several minutes to run the first time\n",
"on commodity machines,\n",
"with most time spent extracting the data.\n",
"On subsequent runs,\n",
"the time-consuming operations will not be repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uL9LHbjdsUbm"
},
"outputs": [],
"source": [
"iam_paragraphs = text_recognizer.data.IAMParagraphs()\n",
"\n",
"iam_paragraphs.prepare_data()\n",
"iam_paragraphs.setup()\n",
"xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n",
"\n",
"iam_paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBkFN9bbTm_S"
},
"source": [
"Now that we've got a batch,\n",
"let's take a look at some samples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hqaps8yxtBhU"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import wandb\n",
"\n",
"\n",
"def show(y):\n",
" y = y.detach().cpu() # bring back from accelerator if it's being used\n",
" return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 04: Experiment Management"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How experiment management brings observability to ML model development\n",
"- Which features of experiment management we use in developing the Text Recognizer\n",
"- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 4\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lab contains a large number of embedded iframes\n",
"that benefit from having a wide window.\n",
"The cell below makes the notebook as wide as your browser window\n",
"if `full_width` is set to `True`.\n",
"Full width is the default behavior in Colab,\n",
"so this cell is intended to improve the viewing experience in other Jupyter environments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Why experiment management?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To understand why we need experiment management for ML development,\n",
"let's start by running an experiment.\n",
"\n",
"We'll train a new model on a new dataset,\n",
"using the training script `training/run_experiment.py`\n",
"introduced in [Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use a CNN encoder and Transformer decoder, as in\n",
"[Lab 03](https://fsdl.me/lab03-colab),\n",
"but with some changes so we can iterate faster.\n",
"We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab),\n",
"and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n",
"from text_recognizer.data import IAMLines # processed version split into individual lines\n",
"from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n",
"\n",
"\n",
"print(IAM.__doc__)\n",
"\n",
"# uncomment a line below for details on either class\n",
"# IAMLines?? \n",
"# LineCNNTransformer??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below will train a model on 10% of the data for two epochs.\n",
"\n",
"It takes up to a few minutes to run on commodity hardware,\n",
"including data download and preprocessing.\n",
"As it's running, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"import torch\n",
"\n",
"\n",
"gpus = int(torch.cuda.is_available()) \n",
"\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the model trains, we're calculating lots of metrics --\n",
"loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n",
"and reporting them to the terminal.\n",
"\n",
"This is achieved by the built-in `.log` method\n",
"([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n",
"of the `LightningModule`,\n",
"and it is a very straightforward way to get basic information about your experiment as it's running\n",
"without leaving the context where you're running it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Learning to read\n",
"[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n",
"is something of a rite of passage for MLEs, but\n",
"let's consider what we can't see here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We're missing all metric values except the most recent --\n",
"we can see them as they stream in, but they're constantly overwritten.\n",
"We also can't associate them with timestamps, steps, or epochs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We also don't see any system metrics.\n",
"We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n",
"without launching a separate process.\n",
"And even if we do, those values will also not be saved and timestamped,\n",
"so we can't correlate them with other things during training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As we continue to run experiments, changing code and opening new terminals,\n",
"even the information we have or could figure out now will disappear.\n",
"Say you spot a weird error message during training,\n",
"but your session ends and the stdout is gone,\n",
"so you don't know exactly what it was.\n",
"Can you recreate the error?\n",
"Which git branch and commit were you on?\n",
"Did you have any uncommitted changes? Which arguments did you pass?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Also, model checkpoints containing the parameter values have been saved to disk.\n",
"Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n",
"As we run more and more experiments,\n",
"we'll want to slice and dice them to see if,\n",
"say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to save and log all of this information, and more, in order to make our model training\n",
"[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n",
"in short, so that we can understand, make decisions about, and debug our model training\n",
"by looking at logs and source code, without having to recreate it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we had to write the code to log all of this information ourselves, we'd be in for a world of hurt:\n",
"1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n",
"2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n",
"3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n",
"4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local Experiment Tracking with TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n",
"and it makes logging very easy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n",
"\n",
"It can also use a logger to push information elsewhere.\n",
"\n",
"By default, we use\n",
"[TensorBoard](https://www.tensorflow.org/tensorboard)\n",
"via the Lightning `TensorBoardLogger`,\n",
"which has been saving results to the local disk.\n",
"\n",
"Let's find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest experiment's directory\n",
"# by hand, you can just copy and paste it from the terminal\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n",
"filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n",
"latest_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!ls -lh {latest_log}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, we need to launch a TensorBoard server --\n",
"much like we need to launch a Jupyter server to use Jupyter notebooks.\n",
"\n",
"The cells below load an extension that lets you use TensorBoard inside of a notebook\n",
"the same way you'd use it from the command line, and then launch it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n",
"\n",
"port = 11717 # pick an open port on your machine\n",
"host = \"0.0.0.0\" # allow connections from the internet\n",
" # watch out! make sure you turn TensorBoard off\n",
"\n",
"%tensorboard --logdir {latest_log} --port {port} --host {host}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see some charts of metrics over time along with some charting controls.\n",
"\n",
"You can click around in this interface and explore it if you'd like,\n",
"but in the next section, we'll see that there are better tools for experiment management."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've run many experiments on this machine,\n",
"you can see all of their results by pointing TensorBoard\n",
"at the whole `lightning_logs` directory,\n",
"rather than just one experiment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For large numbers of experiments, the management experience is not great --\n",
"for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n",
"\n",
"It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n",
"which are important workflows as applications mature and teams grow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n",
"\n",
"If you run into any issues with the above cells failing to launch,\n",
"especially across iterations of this lab, run this cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorboard.manager\n",
"\n",
"# get the process IDs for all tensorboard instances\n",
"pids = [tb.pid for tb in tensorboard.manager.get_all()]\n",
"\n",
"done_with_tensorboard = False\n",
"\n",
"if done_with_tensorboard:\n",
" # kill processes\n",
" for pid in pids:\n",
" !kill {pid} 2> /dev/null\n",
" \n",
" # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n",
" !rm -rf {tensorboard.manager._get_info_dir()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiment Management with Weights & Biases"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How do we manage experiments when we hit the limits of local TensorBoard?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is powerful and flexible and very scalable,\n",
"but running it requires engineering effort and babysitting --\n",
"you're running a database, writing data to it,\n",
"and layering a web application over it.\n",
"\n",
"This is a fairly common workflow for web developers,\n",
"but not so much for ML engineers.\n",
"\n",
"You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n",
"and it's as simple as running the command `tensorboard dev upload`\n",
"pointed at your logging directory.\n",
"\n",
"But there are strict limits to this free service:\n",
"1GB of tensor data and 1GB of binary data.\n",
"A single Text Recognizer model checkpoint is ~100MB,\n",
"and that's not particularly large for a useful model.\n",
"\n",
"Furthermore, all data is public,\n",
"so if you upload the inputs and outputs of your model,\n",
"anyone who finds the link can see them.\n",
"\n",
"Overall, tensorboard.dev works very well for certain academic and open projects\n",
"but not for industrial ML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To avoid these public-visibility and storage-limit issues,\n",
"you could use [git LFS](https://git-lfs.github.com/)\n",
"to track the binary data and tensor data,\n",
"which is more likely to be sensitive than metrics.\n",
"\n",
"The Hugging Face ecosystem uses TensorBoard and git LFS.\n",
"\n",
"It includes the Hugging Face Hub, a git server much like GitHub,\n",
"but designed first and foremost for collaboration on models and datasets,\n",
"rather than collaboration on code.\n",
"For example, the Hugging Face Hub\n",
"[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n",
"and officially has\n",
"[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n",
"avoiding the\n",
"[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n",
"that make using git LFS with GitHub expensive.\n",
"\n",
"However, we prefer to avoid mixing software version control and experiment management.\n",
"\n",
"First, using the Hub requires maintaining an additional git remote,\n",
"which is a hard ask for many engineering teams.\n",
"\n",
"Secondly, git-style versioning is an awkward fit for logging --\n",
"is it really sensible to create a new commit for each logging event while you're watching live?\n",
"\n",
"Instead, we prefer to use systems that solve experiment management with _databases_."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n",
"The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n",
"tool is [MLflow](https://github.com/mlflow/mlflow/)\n",
"and there are a number of\n",
"[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n",
"\n",
"These tools generally avoid any need to worry about hosting\n",
"(unless data governance rules require a self-hosted version).\n",
"\n",
"For a sampling of publicly-posted opinions on experiment management tools,\n",
"see these discussions from Reddit:\n",
"\n",
"- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n",
"- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n",
"\n",
"Among these tools, the FSDL recommendation is\n",
"[Weights & Biases](https://wandb.ai),\n",
"which we believe offers\n",
"- the best user experience, both in the Python SDKs and in the graphical interface\n",
"- the best integrations with other tools,\n",
"including\n",
"[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n",
"[Keras](https://docs.wandb.ai/guides/integrations/keras),\n",
"[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n",
"and even\n",
"[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n",
"and\n",
"- the best tools for collaboration.\n",
"\n",
"Below, we'll take care to point out which logging and management features\n",
"are available via generic interfaces in Lightning and which are W&B-specific."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"print(wandb.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding it to our experiment running code is extremely easy,\n",
"relative to the features we get, which is\n",
"one of the main selling points of W&B.\n",
"\n",
"We get most of our new experiment management features just by changing a single variable, `logger`, from\n",
"`TensorBoardLogger` to `WandbLogger`\n",
"and adding two lines of code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll see what each of these lines does for us below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this logger is built into and maintained by PyTorch Lightning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"\n",
"WandbLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to complete the rest of this notebook,\n",
"you'll need a Weights & Biases account.\n",
"\n",
"As with GitHub, the free tier for personal, academic, and open source work\n",
"is very generous.\n",
"\n",
"The Text Recognizer project will fit comfortably within the free tier.\n",
"\n",
"Run the cell below and follow the prompts to log in or create an account or go\n",
"[here](https://wandb.ai/signup)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the cell below to launch an experiment tracked with Weights & Biases.\n",
"\n",
"The experiment can take between 3 and 10 minutes to run.\n",
"In that time, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n",
" --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1\n",
" \n",
"last_expt = wandb.run\n",
"\n",
"wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see some new things in our output.\n",
"\n",
"For example, there's a note from `wandb` that the data is saved locally\n",
"and also synced to their servers.\n",
"\n",
"There's a link to a webpage for viewing the logged data and a name for our experiment --\n",
"something like `dandy-sunset-1`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The local logging and cloud syncing happen with minimal impact on performance,\n",
"because `wandb` launches a separate process to listen for events and upload them.\n",
"\n",
"That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, head to the link in the notebook output\n",
"that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n",
"\n",
"There's no need to wait for training to finish.\n",
"\n",
"The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(last_expt.url)\n",
"IFrame(last_expt.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have landed on the run page\n",
"([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n",
"which collects up all of the information for a single experiment into a collection of tabs.\n",
"\n",
"We'll work through these tabs from top to bottom.\n",
"\n",
"Each header is also a link to the documentation for a tab."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n",
"This tab has an icon that looks like `(i)` or 🛈.\n",
"\n",
"The top section of this tab has high-level information about our run:\n",
"- Timing information, like start time and duration\n",
"- System hardware, hostname, and basic environment info\n",
"- Git repository link and state\n",
"\n",
"This information is collected and logged automatically.\n",
"\n",
"The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n",
"and summary metrics.\n",
"\n",
"Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n",
"\n",
"This tab has a line plot icon, something like 📈.\n",
"\n",
"It's also the default page you land on when looking at a W&B run.\n",
"\n",
"Charts are generated for everything we `.log` from PyTorch Lightning. The charts here are interactive and editable, and changes persist.\n",
"\n",
"Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n",
"\n",
"We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n",
"This tab has a computer chip icon.\n",
"\n",
"It contains\n",
"- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n",
"- CPU metrics: memory usage, utilization, thread counts\n",
"- Disk and network I/O levels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n",
"This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n",
"\n",
"The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n",
"This tab has an icon that looks like a stylized command prompt, `>_`.\n",
"\n",
"It contains information that was printed to the stdout.\n",
"\n",
"This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n",
"\n",
"Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelSummary\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n",
"\n",
"For more on Lightning `Callback`s, see\n",
"[Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n",
"This tab has a stylized document icon, something like 📄.\n",
"\n",
"You can use this tab to view any files saved with `wandb.save`.\n",
"\n",
"For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n",
"which we'll discuss shortly.\n",
"\n",
"But a few pieces of information automatically collected by W&B end up in this tab.\n",
"\n",
"Some highlights:\n",
" - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n",
" - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n",
"This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n",
"\n",
"This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n",
"\n",
"We store two kinds of binary files\n",
" - `run_table`s of model inputs and outputs\n",
" - `model` checkpoints\n",
"\n",
"We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interactive Tables of Logged Media"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returning to the Charts tab,\n",
"notice that we have model inputs and outputs logged in structured tables\n",
"under the train, validation, and test sections.\n",
"\n",
"These tables are interactive as well\n",
"([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n",
"They support basic exploratory data analysis and are compatible with W&B's collaboration features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to charts in our run page, these tables also have their own pages inside the W&B web app."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n",
"table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n",
"\n",
"print(table_data_url)\n",
"IFrame(src=table_data_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting this to work requires more effort and more W&B-specific code\n",
"than the other features we've seen so far.\n",
"\n",
"We'll briefly explain the implementation here, for those who are interested.\n",
"\n",
"We use a custom Lightning `Callback`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n",
"\n",
"\n",
"ImageToTextTableLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, Lightning returns logged information on every batch, and these outputs are accumulated throughout an epoch.\n",
"\n",
"The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n",
"\n",
"This behavior is sensible for metrics, which are low overhead, but not so much for media,\n",
"where we'd rather subsample and avoid holding on to too much information.\n",
"\n",
"So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n",
"\n",
"The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.lit_models.base import BaseImageToTextLitModel\n",
"\n",
"BaseImageToTextLitModel.add_on_logged_batches??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Projects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Everything we've seen so far has been related to a single run or experiment.\n",
"\n",
"Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n",
"\n",
"We organize our runs into \"projects\" and view them on the W&B \"project page\" \n",
"([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n",
"\n",
"By default in the Lightning integration, the project name is determined based on directory information.\n",
"This default can be over-ridden in the code when creating a `WandbLogger`,\n",
"but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what the project page looks like for a longer-running project with lots of experiments.\n",
"\n",
"The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n",
"\n",
"print(project_url)\n",
"IFrame(src=project_url, width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n",
"\n",
"We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Artifacts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond logging metrics and metadata from runs,\n",
"we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below pulls up all of the artifacts associated with the experiment we just ran."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Click on one of the `model` checkpoints -- the specific version doesn't matter.\n",
"\n",
"There are a number of tabs here.\n",
"\n",
"The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n",
"\n",
"The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n",
"which are added by default by the `WandbLogger`.\n",
"\n",
"The \"Files\" tab contains the actual file contents of the artifact.\n",
"\n",
"On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n",
"including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n",
"\n",
"You can click on these to explore the different versions and even directly compare them.\n",
"\n",
"If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Artifact storage is part of the W&B free tier.\n",
"\n",
"The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n",
"\n",
"That is sufficient to store ~700 model checkpoints for the Text Recognizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can track your data storage and compare it to your limits at this URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n",
"\n",
"print(storage_tracker_url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Programmatic Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also programmatically access our data and metadata via the `wandb` API\n",
"([docs](https://docs.wandb.ai/guides/track/public-api-guide)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wb_api = wandb.Api()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run = wb_api.run(\"/\".join( # fetch a run given\n",
" [last_expt.entity, # the user or org it was logged to\n",
" last_expt.project, # the \"project\", usually one of several per repo/application\n",
" last_expt.id] # and a unique ID\n",
"))\n",
"\n",
"hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n",
"\n",
"hist.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hist.groupby(\"epoch\")[\"train/loss\"].mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this includes the artifacts:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# which artifacts were created and logged?\n",
"artifacts = run.logged_artifacts()\n",
"\n",
"for artifact in artifacts:\n",
" print(f\"artifact of type {artifact.type}: {artifact.name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thanks to our `ImageToTextTableLogger`,\n",
"we can easily recreate training or validation data that came out of our `DataLoader`s,\n",
"which is normally ephemeral:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n",
"artifact_dir = Path(artifact.download(root=\"training/logs\"))\n",
"image_dir = artifact_dir / \"media\" / \"images\"\n",
"\n",
"images = [path for path in image_dir.iterdir()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"from IPython.display import Image\n",
"\n",
"Image(str(random.choice(images)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Advanced W&B API Usage: MLOps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the strengths of a well-instrumented experiment tracking system is that it allows\n",
"automatic relation of information:\n",
"what were the inputs when this model's gradient spiked?\n",
"Which models have been trained on this dataset,\n",
"and what was their performance?\n",
"\n",
"Having access and automation around this information is necessary for \"MLOps\",\n",
"which applies contemporary DevOps principles to ML projects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below pull down the training data\n",
"for the model currently running the FSDL Text Recognizer app.\n",
"\n",
"This is just intended as a demonstration of what's possible,\n",
"so don't worry about understanding every piece of this,\n",
"and feel free to skip past it.\n",
"\n",
"MLOps is still a nascent field, and these tools and workflows are likely to change.\n",
"\n",
"For example, just before the course launched, W&B released a\n",
"[Model Registry layer](https://docs.wandb.ai/guides/models)\n",
"on top of artifact logging that aims to improve the developer experience for these workflows."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start from the same project we looked at in the project view:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n",
"\n",
"text_recognizer_project "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we search it for the text recognizer model currently being used in production:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# collect all versions of the text-recognizer ever put into production by...\n",
"\n",
"for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n",
" if art_type.name == \"prod-ready\": # for the prod-ready type\n",
" # and grabbing the text-recognizer\n",
" production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n",
"\n",
"# and then get the one that's currently being tested in CI by...\n",
"for text_recognizer in production_text_recognizers:\n",
" if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n",
" in_prod_text_recognizer = text_recognizer\n",
"\n",
"# view its metadata at the url or in the notebook\n",
"in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n",
"\n",
"print(in_prod_text_recognizer_url)\n",
"IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From its metadata, we can get information about how it was \"staged\" to be put into production,\n",
"and in particular which model checkpoint was used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staging_run = in_prod_text_recognizer.logged_by()\n",
"\n",
"training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n",
"training_ckpt.name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That checkpoint was logged by a training experiment, which is available as metadata.\n",
"\n",
"We can look at the training run for that model, either here in the notebook or at its URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"training_run = training_ckpt.logged_by()\n",
"print(training_run.url)\n",
"IFrame(src=training_run.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And from there, we can access logs and metadata about training,\n",
"confident that we are working with the model that is actually in production.\n",
"\n",
"For example, we can pull down the data we logged and analyze it locally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_results = training_run.history(samples=10000)\n",
"training_results.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n",
"training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 10\n",
"training_results[\"validation/loss\"].dropna().iloc[10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The charts and webpages in Weights & Biases\n",
"are substantially more useful than ephemeral stdouts or raw logs on disk.\n",
"\n",
"If you're spun up on the project,\n",
"they accelerate debugging, exploration, and discovery.\n",
"\n",
"If not, they're not so much useful as they are overwhelming.\n",
"\n",
"We need to synthesize the raw logged data into information.\n",
"This helps us communicate our work with other stakeholders,\n",
"preserve knowledge and prevent repetition of work,\n",
"and surface insights faster.\n",
"\n",
"These workflows are supported by the W&B Reports feature\n",
"([docs here](https://docs.wandb.ai/guides/reports)),\n",
"which mixes W&B charts and tables with explanatory markdown text and embeds.\n",
"\n",
"Below are some common report patterns and\n",
"use cases and examples of each."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of the examples are from the FSDL Text Recognizer project.\n",
"You can find more of them\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n",
"where we've organized them into a report!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dashboard Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dashboards are a structured subset of the output from one or more experiments,\n",
"designed for quickly surfacing issues or insights,\n",
"like an accuracy or performance regression\n",
"or a change in the data distribution.\n",
"\n",
"Use cases:\n",
"- show the basic state of an ongoing experiment\n",
"- compare one experiment to another\n",
"- select the most important charts so you can spin back up into context on a project more quickly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n",
"\n",
"IFrame(src=dashboard_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pull Request Documentation Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In most software codebases,\n",
"pull requests are a key focal point\n",
"for units of work that combine\n",
"short-term communication and long-term information tracking.\n",
"\n",
"In ML codebases, it's more difficult to bring\n",
"sufficient information together to make PRs as useful.\n",
"At FSDL, we like to add documentary\n",
"reports with one or a small number of charts\n",
"that connect logged information in the experiment management system\n",
"to state in the version control software.\n",
"\n",
"Use cases:\n",
"- communication of results within a team, e.g. code review\n",
"- record-keeping that links pull request pages to raw logged info and makes it discoverable\n",
"- improving confidence in PR correctness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n",
"\n",
"IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Blog Post Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With sufficient effort, the logged data in the experiment management system\n",
"can be made clear enough to be consumed,\n",
"sufficiently contextualized to be useful outside the team, and\n",
"even beautiful.\n",
"\n",
"The result is a report that's closer to a blog post than a dashboard or internal document.\n",
"\n",
"Use cases:\n",
"- communication between teams or vertically in large organizations\n",
"- external technical communication for branding and recruiting\n",
"- attracting users or contributors\n",
"\n",
"Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n",
"[Boris Dayma](https://twitter.com/borisdayma)\n",
"and others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n",
"\n",
"IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameter Optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many of our choices, like the depth of our network, the nonlinearities of our layers,\n",
"and the learning rate and other parameters of our optimizer, cannot be\n",
"([easily](https://arxiv.org/abs/1606.04474))\n",
"chosen by descent of the gradient of a loss function.\n",
"\n",
"But these parameters that impact the values of the parameters\n",
"we directly optimize with gradients, or _hyperparameters_,\n",
"can still be optimized,\n",
"essentially by trying options and selecting the values that worked best.\n",
"\n",
"In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n",
"\n",
"Expending more compute can squeeze out small amounts of additional validation or test performance\n",
"that makes for impressive results on leaderboards but typically doesn't translate\n",
"into better user experience.\n",
"\n",
"In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n",
"built into your other tooling.\n",
"\n",
"Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n",
"([docs](https://docs.wandb.ai/guides/sweeps)).\n",
"\n",
"It also supports a number of more advanced tools, like\n",
"[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n",
"for early termination of poorly-performing runs.\n",
"\n",
"We can use the same training script and we don't need to run an optimization server.\n",
"\n",
"We just need to write a configuration yaml file\n",
"([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n",
"like the one below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile training/simple-overfit-sweep.yaml\n",
"# first we specify what we're sweeping\n",
"# we specify a program to run\n",
"program: training/run_experiment.py\n",
"# we optionally specify how to run it, including setting default arguments\n",
"command: \n",
" - ${env}\n",
" - ${interpreter}\n",
" - ${program}\n",
" - \"--wandb\"\n",
" - \"--overfit_batches\"\n",
" - \"1\"\n",
" - \"--log_every_n_steps\"\n",
" - \"25\"\n",
" - \"--max_epochs\"\n",
" - \"100\"\n",
" - \"--limit_test_batches\"\n",
" - \"0\"\n",
" - ${args} # these arguments come from the sweep parameters below\n",
"\n",
"# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n",
"method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n",
"metric:\n",
" name: train/loss\n",
" goal: minimize\n",
"parameters: \n",
" # LineCNN hyperparameters\n",
" window_width:\n",
" values: [8, 16, 32, 64]\n",
" window_stride:\n",
" values: [4, 8, 16, 32]\n",
" # Transformer hyperparameters\n",
" tf_layers:\n",
" values: [1, 2, 4, 8]\n",
" # we can also fix some values, just like we set default arguments\n",
" gpus:\n",
" value: 1\n",
" model_class:\n",
" value: LineCNNTransformer\n",
" data_class:\n",
" value: IAMLines\n",
" loss:\n",
" value: transformer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the config we launch a \"controller\":\n",
"a lightweight process that just decides what hyperparameters to try next\n",
"and coordinates the heavier-weight training.\n",
"\n",
"This lives on the W&B servers, so there are no headaches about opening ports for communication,\n",
"cleaning up when it's done, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n",
"simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we can launch an \"agent\" to follow the orders of the controller:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# interrupt twice to terminate this cell if it's running too long,\n",
"# it can be over 15 minutes with some hyperparameters\n",
"\n",
"!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n",
"\n",
"If not provided, for random or Bayesian sweeps the agent will keep running\n",
"until the sweep is terminated, which can be done from the W&B interface."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agents make for a slick workflow for distributing sweeps across GPUs.\n",
"\n",
"We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n",
"which controls which GPUs are accessible by a process, to launch\n",
"parallel agents on separate GPUs on the same machine."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n",
"# open another terminal\n",
"CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n",
"# and so on\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We include optional exercises with the labs for learners who want to dive deeper on specific topics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟 Contribute to a hyperparameter search."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n",
"\n",
"There are ~10,000,000 potential hyperparameter combinations,\n",
"and each takes 30 minutes to test,\n",
"so checking each possibility will take over 500 years of compute time.\n",
"Best get cracking then!\n",
"\n",
"Run the cell below to pull up a dashboard and print the URL where you can check on the current status."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_entity = \"fullstackdeeplearning\"\n",
"sweep_project = \"fsdl-line-recognizer-2022\"\n",
"sweep_id = \"e0eo43eu\"\n",
"sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n",
"\n",
"print(sweep_url)\n",
"IFrame(src=sweep_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also retrieve information about the sweep from the API,\n",
"including the hyperparameters being swept over."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hyperparams = sweep_info.config[\"parameters\"]\n",
"hyperparams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to contribute to this sweep,\n",
"run the cell below after changing the count to a number greater than 0.\n",
"\n",
"Each iteration runs for 30 minutes if it does not crash,\n",
"e.g. due to out-of-memory errors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"count = 0 # off by default, increase it to join in!\n",
"\n",
"if count:\n",
" !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Write some manual logging in `wandb`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the FSDL Text Recognizer codebase,\n",
"we almost exclusively log to W&B through Lightning,\n",
"rather than through the `wandb` Python SDK.\n",
"\n",
"If you're interested in learning how to use W&B directly, e.g. with another training framework,\n",
"try out this quick exercise that introduces the key players in the SDK."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n",
"\n",
"It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n",
"\n",
"Add W&B metric and artifact logging to this cell:\n",
"- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n",
"- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"import random\n",
"\n",
"import wandb\n",
"\n",
"\n",
"os.makedirs(\"logs\", exist_ok=True)\n",
"\n",
"project = \"trying-wandb\"\n",
"config = {\"steps\": 50}\n",
"\n",
"\n",
"with wandb.init(project=project, config=config) as run:\n",
" steps = wandb.config[\"steps\"]\n",
" \n",
" for ii in range(steps):\n",
" loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n",
" \n",
" with open(\"logs/hello.txt\", \"w\") as f:\n",
" f.write(\"hello from wandb, my dudes!\")\n",
" \n",
" run_id = run.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hello_run = wb_api.run(f\"{project}/{run_id}\")\n",
"\n",
"# check for logged loss data\n",
"if \"loss\" not in hello_run.history().keys():\n",
" print(\"loss not logged 🥲\")\n",
"else:\n",
" print(\"loss logged successfully 🥞\")\n",
" if len(hello_run.history()[\"loss\"]) != steps:\n",
" print(\"loss not logged on all steps 🥲\")\n",
" else:\n",
" print(\"loss logged on all steps 🥞\")\n",
"\n",
"artifacts = hello_run.logged_artifacts()\n",
"\n",
"# check for artifact with the right name\n",
"if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n",
" print(\"hello artifact not logged 🥲\")\n",
"else:\n",
" print(\"hello artifact logged successfully 🥞\")\n",
" # check for the file inside the artifacts\n",
" if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n",
" print(\"could not find hello.txt 🥲\")\n",
" else:\n",
" print(\"hello.txt logged successfully 🥞\")\n",
" \n",
" \n",
"hello_run"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try to find some better hyperparameters: choices that achieve a lower loss on the full dataset faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you observe interesting phenomena during training,\n",
"from promising hyperparameter combos to software bugs to strange model behavior,\n",
"turn the charts into a W&B report and share it with the FSDL community or\n",
"[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n",
"with a link to them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# check the sweep_info.config above to see the model and data hyperparameters\n",
"# read through the --help output for all potential arguments\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n",
" --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n",
" --help # remove this line to run an experiment instead of printing help\n",
" \n",
"last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟🌟🌟 Add logging of tensor statistics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to logging model inputs and outputs as human-interpretable media,\n",
"it's also frequently useful to see information about their numerical values."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're interested in learning more about metric calculation and logging with Lightning,\n",
"use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n",
"to add tensor statistic logging to the `LineCNNTransformer`.\n",
"\n",
"`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n",
"\n",
"All three are useful, but start by adding just one."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_model/base.py` and `text_recognizer/lit_model/transformer.py` files\n",
"- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n",
" - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n",
"- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n",
" - Base your code on the calculation and logging of the `val_cer` metric.\n",
" - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ",
"collapsed_sections": [],
"name": "lab04_experiments.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab06/notebooks/lab05_troubleshooting.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 05: Troubleshooting & Testing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`\n",
"- How to implement tests for ML training systems in particular\n",
"- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 5\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sThWeTtV6fL_"
},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xFP8lU4nSg1P"
},
"source": [
"# Linting Python and Shell Scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cXbdYfFlPhZ-"
},
"source": [
"### Automatically linting with `pre-commit`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ysqqb2GjvLrz"
},
"source": [
"We want to keep our code clean and uniform across developers\n",
"and time.\n",
"\n",
"Applying the cleanliness checks and style rules should be\n",
"as painless and automatic as possible.\n",
"\n",
"For this purpose, we recommend bundling linting tools together\n",
"and enforcing them on all commits with\n",
"[`pre-commit`](https://pre-commit.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XvqtZChKvLr0"
},
"source": [
"In addition to running on every commit,\n",
"`pre-commit` separates the model development environment from the environments\n",
"needed for the linting tools, preventing conflicts\n",
"and simplifying maintenance and onboarding."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y0XuIuKOXhJl"
},
"source": [
"This cell runs `pre-commit`.\n",
"\n",
"The first time it is run on a machine, it will install the environments for all tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hltYGbpNvLr1"
},
"outputs": [],
"source": [
"!pre-commit run --all-files"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gLw08gIkvLr1"
},
"source": [
"The output lists all the checks that are run and whether they are passed.\n",
"\n",
"Notice there are a number of simple version-control hygiene practices included\n",
"that aren't even specific to Python, much less to machine learning.\n",
"\n",
"For example, several of the checks prevent accidental commits with private keys, large files, \n",
"leftover debugger statements, or merge conflict annotations in them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RHEEjb9kvLr1"
},
"source": [
"These linting actions are configured via\n",
"([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n",
"a YAML file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dgXa8BzrvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8HYc_WbTvLr2"
},
"source": [
"Most of the general cleanliness checks are from hooks built by `pre-commit`.\n",
"\n",
"See the comments and links in the `.pre-commit-config.yaml` for more:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K9rTgRqzvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep repos -A 15"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ptkO7aPvLr2"
},
"source": [
"Let's take a look at the section of the file\n",
"that applies most of our Python style enforcement with\n",
"[`flake8`](https://flake8.pycqa.org/en/latest/):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ALsRKfcevLr3",
"scrolled": true
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_Q0BwQUXbg6"
},
"source": [
"The majority of the style checking behavior we want comes from the\n",
"`additional_dependencies`, which are\n",
"[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n",
"that extend `flake8`'s list of lints.\n",
"\n",
"Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n",
"\n",
"We keep the configuration information for `flake8`\n",
"separate from that for `pre-commit`\n",
"in case we want to use additional tools with `flake8`,\n",
"e.g. if some developers want to integrate it directly into their editor,\n",
"and so that if we change away from `pre-commit`\n",
"but keep `flake8` we don't have to\n",
"recreate our configuration in a different tool.\n",
"\n",
"As much as possible, codebases should strive for single sources of truth\n",
"and link back to those sources of truth with documentation or comments,\n",
"as in the last line above.\n",
"\n",
"Let's take a look at the contents of `.flake8`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "doC_4WQwvLr3"
},
"outputs": [],
"source": [
"!cat .flake8"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Nq6HnyU0M47"
},
"source": [
"There's a lot here! We'll focus on the most important bits."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U4PiB8CPvLr3"
},
"source": [
"Linting tools in Python generally work by emitting error codes\n",
"with one or more letters followed by three numbers.\n",
"The `select` argument picks which error codes we want to check for.\n",
"Error codes are matched by prefix,\n",
"so for example `B` matches `BTS101` and\n",
"`G1` matches `G102` and `G199` but not `ARG404`.\n",
"\n",
"Certain codes are `ignore`d in the default `flake8` style,\n",
"which is done via the `ignore` argument,\n",
"and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n",
"For example, we rely on `black` to do our formatting,\n",
"so we ignore some of `flake8`'s formatting codes.\n",
"\n",
"Together, these settings define our project's particular style.\n",
"\n",
"But not every file fits this style perfectly.\n",
"Most of the conventions in `black` and `flake8` come from the style-defining\n",
"[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n",
"which exhorts you to \"know when to be inconsistent\".\n",
"\n",
"To allow ourselves to be inconsistent when we know we should be,\n",
"`flake8` includes `per-file-ignores`,\n",
"which let us ignore specific warnings in specific files.\n",
"This is one of the \"escape valves\"\n",
"that makes style enforcement tolerable.\n",
"We can also `exclude` files in the `pre-commit` config itself.\n",
"\n",
"For details on selecting and ignoring,\n",
"see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html).\n",
"\n",
"For definitions of the error codes from `flake8` itself,\n",
"see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n",
"Individual extensions list their added error codes in their documentation,\n",
"e.g. `darglint` does so\n",
"[here](https://github.com/terrencepreilly/darglint#error-codes)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NL0TpyPsvLr4"
},
"source": [
"The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n",
"\n",
"You can read more about each in their documentation:\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mFsZC0a7vLr4"
},
"source": [
"### Linting via a script and using `shellcheck`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RYjpuFwjXkJc"
},
"source": [
"To avoid needing to think about `pre-commit`\n",
"(was the command `pre-commit run` or `pre-commit check`?)\n",
"while developing locally,\n",
"we might put our linters into a shell script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mXlLFWmavLr4"
},
"outputs": [],
"source": [
"!cat tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PPxHpRIB3nbw"
},
"source": [
"These kinds of short and simple shell scripts are common in projects\n",
"of intermediate size.\n",
"\n",
"They are useful for adding automation and reducing friction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TMuPBpAi2qwl"
},
"source": [
"But these scripts are code,\n",
"and all code is susceptible to bugs and subject to concerns of style consistency."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SQRg3ZqXvLr4"
},
"source": [
"We can't check these scripts with tools that lint Python code,\n",
"so we include a shell script linting tool,\n",
"[`shellcheck`](https://www.shellcheck.net/),\n",
"in our `pre-commit`.\n",
"\n",
"More so than checking for correct style,\n",
"this tool checks for common bugs or surprising behaviors of shells,\n",
"which are unfortunately numerous."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zkfhE1srvLr4"
},
"outputs": [],
"source": [
"script_filename = \"tasks/lint.sh\"\n",
"!pre-commit run shellcheck --files {script_filename}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXU9TRrwvLr4"
},
"source": [
"That script has already been tested, so we don't see any errors.\n",
"\n",
"Try copying over a script you've written yourself or\n",
"even from a popular repo that you like\n",
"(by adding to the notebook directory or by making a cell\n",
"with `%%writefile` at the top)\n",
"and test it by changing the `script_filename`.\n",
"\n",
"You'd be surprised at the classes of subtle bugs possible in bash!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "81MhAL-TvLr5"
},
"source": [
"### Try \"unofficial bash strict mode\" for louder failures in scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hSwhs_zUvLr5"
},
"source": [
"Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n",
"[@redsymbol](https://twitter.com/redsymbol),\n",
"which appear at the top of the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o-j0vSxEvLr5"
},
"outputs": [],
"source": [
"!head -n 3 tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d2iJU5jlvLr5"
},
"source": [
"The core idea of strict mode is to fail more loudly.\n",
"This is a desirable behavior of scripts,\n",
"like the ones we're writing,\n",
"even though it's an undesirable behavior for an interactive shell --\n",
"it would be unpleasant to be logged out every time you hit an error.\n",
"\n",
"`set -u` means scripts fail if a variable's value is `u`nset,\n",
"i.e. not defined.\n",
"Otherwise bash is perfectly happy to allow you to reference undefined variables.\n",
"The result is just an empty string, which can lead to maddeningly weird behavior.\n",
"\n",
"`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n",
"rather than using the exit code of the last command.\n",
"Unix tools are perfectly happy to work on nonsense input,\n",
"like sorting error messages, instead of the filenames you meant to send.\n",
"\n",
"You can read more about these choices\n",
"[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n",
"along with considerations for working with other non-conforming scripts in \"strict mode\"\n",
"and for handling resource teardown when scripts error out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s1XqsrU_XWWS"
},
"source": [
"# Testing ML Codebases"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPNzeq3NYF2W"
},
"source": [
"## Testing Python code with `pytest`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zq5e_x6gc9Vu"
},
"source": [
"\n",
"ML codebases are Python first and foremost, so first let's get some Python tests going."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DC3GxYz6_R9"
},
"source": [
"At a basic level,\n",
"we can write functions that `assert`\n",
"that our code behaves as expected in\n",
"a given scenario and include them in the same module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rvd-GNwv63W1"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.metrics import test_character_error_rate\n",
"\n",
"test_character_error_rate??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iVB2TsQS5BTq"
},
"source": [
"The standard tool for testing Python code is\n",
"[`pytest`](https://docs.pytest.org/en/7.1.x/).\n",
"\n",
"We can use it as a command-line tool in a variety of ways,\n",
"including to execute these kinds of tests.\n",
"\n",
"If passed a filename, `pytest` will look for\n",
"any classes that start with `Test` or\n",
"any functions that start with `test_` and run them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u8sQguyJvLr6",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest text_recognizer/lit_models/metrics.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "92tkBCllvLr6"
},
"source": [
"After the results of the tests (pass or fail) are returned,\n",
"you'll see a report of \"coverage\" from\n",
"[`pytest-cov`](https://pytest-cov.readthedocs.io/), a `pytest` plugin built on `coverage.py`.\n",
"\n",
"This coverage report tells us which files and how many lines in those files\n",
"were touched by the testing suite."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PllSUe0s5xvU"
},
"source": [
"We do not actually need to provide the names of files with tests in them to `pytest`\n",
"in order for it to run our tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4qOBHJnTZM9x"
},
"source": [
"By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n",
"\n",
"It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n",
"to separate these from the rest of your code\n",
"in a folder or folders named `tests`,\n",
"rather than scattering them around the repo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "acjsYTNSvLr6"
},
"outputs": [],
"source": [
"!ls text_recognizer/tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WZQQZUF0vLr6"
},
"source": [
"Let's take a look at a specific example:\n",
"the tests for some of our utilities around\n",
"custom PyTorch Lightning `Callback`s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oS0xKv1evLr6"
},
"outputs": [],
"source": [
"from text_recognizer.tests import test_callback_utils\n",
"\n",
"\n",
"test_callback_utils.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lko8msn-vLr7"
},
"source": [
"Notice that we can easily import this as a module!\n",
"\n",
"That's another benefit of organizing tests into specialized files."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5A85FUNv75Fr"
},
"source": [
"The particular utility we're testing\n",
"here is designed to prevent crashes:\n",
"it checks for a particular type of error and turns it into a warning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jl4-DiVe76sw"
},
"outputs": [],
"source": [
"from text_recognizer.callbacks.util import check_and_warn\n",
"\n",
"check_and_warn??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B6E0MhduvLr7"
},
"source": [
"Error-handling code is a common cause of bugs,\n",
"a fact discovered\n",
"[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n",
"so it's very important to test it well!\n",
"\n",
"We start with a very basic test,\n",
"which does not touch anything\n",
"outside of the Python standard library,\n",
"even though this tool is intended to be used\n",
"with more complex features of third-party libraries,\n",
"like `wandb` and `tensorboard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xx5koQmJvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_simple??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MZe9-JVjvLr7"
},
"source": [
"Here, we are just testing the core logic.\n",
"This test won't catch many bugs,\n",
"but when it does fail, something has gone seriously wrong.\n",
"\n",
"These kinds of tests are important for resolving a bug:\n",
"we learn nearly as much from the tests that passed\n",
"as we did from the tests that failed.\n",
"If this test has failed, possibly along with others,\n",
"we can rule out an issue in one of the large external codebases\n",
"touched in the other tests, saving us lots of time in our troubleshooting.\n",
"\n",
"The reasoning for the test is explained in the docstrings, \n",
"which are close to the code.\n",
"\n",
"Your test suite should be as welcoming\n",
"as the rest of your codebase!\n",
"The people reading it, for example yourself in six months, \n",
"are likely upset and in need of some kindness.\n",
"\n",
"More practically, we want to keep our time to resolve errors as short as possible,\n",
"and five minutes to write a good docstring now\n",
"can save five minutes during an outage, when minutes really matter."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Om9k-uXhvLr7"
},
"source": [
"That basic test is a start, but it's not enough by itself.\n",
"There's a specific error case that triggered the addition of this code.\n",
"\n",
"So we test that it's handled as expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fjbsb5FvvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_tblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGAIZTUjvLr7"
},
"source": [
"That test can fail if the libraries change around our code,\n",
"i.e. if the `TensorBoardLogger` gets a `log_table` method.\n",
"\n",
"We want to be careful when making assumptions\n",
"about other people's software,\n",
"especially for fast-moving libraries like Lightning.\n",
"If we test that those assumptions hold willy-nilly,\n",
"we'll end up with tests that fail because of\n",
"harmless changes in our dependencies.\n",
"\n",
"Tests that require a ton of maintenance and updating\n",
"without leading to code improvements soak up\n",
"more engineering time than they save\n",
"and cause distrust in the testing suite.\n",
"\n",
"We include this test because `TensorBoardLogger` getting\n",
"a `log_table` method will _also_ change the behavior of our code\n",
"in a breaking way, and we want to catch that before it breaks\n",
"a model training job."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jsy95KAvvLr7"
},
"source": [
"Adding error handling can also accidentally kill the \"happy path\"\n",
"by raising an error incorrectly.\n",
"\n",
"So we explicitly test the _absence of an error_,\n",
"not just its presence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LRlIOkjmvLr8"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_wandblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osiqpLynvLr8"
},
"source": [
"There are more tests we could build, e.g. manipulating classes and testing the behavior,\n",
"testing more classes that might be targeted by `check_and_warn`, or\n",
"asserting that warnings are raised to the command line.\n",
"\n",
"But these three basic tests are likely to catch most changes that would break our code here,\n",
"and they're a lot easier to write than the others.\n",
"\n",
"If this utility starts to get more usage and become a critical path for lots of features, we can always add more!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dm285JE5vLr8"
},
"source": [
"## Interleaving testing and documentation with `doctests`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHWQvgA8vLr8"
},
"source": [
"One function of tests is to build user/reader confidence in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wrhiJBXFvLr8"
},
"source": [
"One function of documentation is to build user/reader knowledge in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1vu12LDhvLr8"
},
"source": [
"These functions are related. Let's put them together:\n",
"put code in a docstring and test that code.\n",
"\n",
"This feature is part of the\n",
"Python standard library via the\n",
"[`doctest` module](https://docs.python.org/3/library/doctest.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rmfIOwXd-Qt7"
},
"source": [
"Here's an example from our `torch` utilities.\n",
"\n",
"The `first_appearance` function can be used to\n",
"e.g. quickly look for stop tokens,\n",
"giving the length of each sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZzURGcD9vLr8"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.util import first_appearance\n",
"\n",
"\n",
"first_appearance??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0VtYcJ1WvLr8"
},
"source": [
"Notice that in the \"Examples\" section,\n",
"there's a short block of code formatted as a\n",
"Python interpreter session,\n",
"complete with outputs.\n",
"\n",
"We can copy and paste that code and\n",
"check that we get the right outputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dj4lNOxJvLr9"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y9AWHFoIvLr9"
},
"source": [
"We can run the test with `pytest` by passing a command line argument,\n",
"`--doctest-modules`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JMaAxv5ovLr9"
},
"outputs": [],
"source": [
"!pytest --doctest-modules text_recognizer/lit_models/util.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6-2_aOUfvLr9"
},
"source": [
"With the\n",
"[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n",
"running `doctest`s happens automatically\n",
"when `pytest` is invoked."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "my_keokPvLr9"
},
"source": [
"## Basic tests for data code"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qj3Bq_j2_A8o"
},
"source": [
"ML code can be hard to test\n",
"since it involves very heavy artifacts, like models and data,\n",
"and very expensive jobs, like training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DT5OmgrQvLr9"
},
"source": [
"For testing our data-handling code in the FSDL codebase,\n",
"we mostly just use `assert`s,\n",
"which throw errors when behavior differs from expectation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bdzn5g4TvLr9"
},
"outputs": [],
"source": [
"!grep \"assert\" -r text_recognizer/data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2aTlfu4_vLr-"
},
"source": [
"This isn't great practice,\n",
"especially as a codebase grows,\n",
"because we can't easily know when these are executed\n",
"or incorporate them into\n",
"testing automation and coverage analysis tools."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IaMTdmbZ_mkW"
},
"source": [
"So it's preferable to collect up these assertions of simple data properties\n",
"into tests that are run like our other tests.\n",
"\n",
"The test below checks whether any data is leaking\n",
"between training, validation, and testing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qx7cxiDdvLr-"
},
"outputs": [],
"source": [
"from text_recognizer.tests.test_iam import test_iam_data_splits\n",
"\n",
"\n",
"test_iam_data_splits??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16TJwhd1vLr-"
},
"source": [
"Notice that we were able to load the test into the notebook\n",
"because it is in a module,\n",
"and so we can run it here as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mArITFkYvLr-"
},
"outputs": [],
"source": [
"test_iam_data_splits()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E4F2uaclvLr-"
},
"source": [
"But we're checking something pretty simple here,\n",
"so the new code in each test is just a single line.\n",
"\n",
"What if we wanted to test more complex properties,\n",
"like comparing rows or calculating statistics?\n",
"\n",
"We'll end up writing more complex code that might itself have subtle bugs,\n",
"requiring tests for our tests and suffering from\n",
"\"tester's regress\".\n",
"\n",
"This phenomenon,\n",
"named by analogy with\n",
"[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n",
"in the sociology of science,\n",
"occurs when the validity of our tests is itself\n",
"up for dispute, resolvable only by testing the tests --\n",
"but those tests are themselves possibly invalid."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nUGT06gdvLr-"
},
"source": [
"We cut this Gordian knot by using\n",
"a library or framework that is well-tested.\n",
"\n",
"We recommend checking out\n",
"[`great_expectations`](https://docs.greatexpectations.io/docs/)\n",
"if you're looking for a high-quality data testing tool."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dQ5vNsq3vLr-"
},
"source": [
"Especially with data, some tests are particularly \"heavy\" --\n",
"they take a long time,\n",
"and we might want to run them\n",
"on different machines\n",
"and on a different schedule\n",
"than our other tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xephcb0LvLr-"
},
"source": [
"For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n",
"\n",
"We can't just use a cached version of the data,\n",
"since that won't actually execute the code!\n",
"\n",
"This test will take\n",
"as long to run\n",
"and consume as many resources as\n",
"a full download of the data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YSN4w2EqvLr-"
},
"source": [
"`pytest` allows the separation of tests\n",
"into suites with `mark`s,\n",
"which \"tag\" tests with names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V0rScrcXvLr_",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest --markers | head -n 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lr5Ca7B0vLr_"
},
"source": [
"We can choose to run tests with a given mark\n",
"or to skip tests with a given mark, \n",
"among other basic logical operations around combining and filtering marks,\n",
"with `-m`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xmw-Eb1ZvLr_"
},
"outputs": [],
"source": [
"!wandb login # one test requires wandb authentication\n",
"\n",
"!pytest -m \"not data and not slow\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5LuERxOXX_UJ"
},
"source": [
"## Testing training with memorization tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AnWLN4lRvLsA"
},
"source": [
"Training is the process by which we convert inert data into executable models,\n",
"so it is dependent on both.\n",
"\n",
"We decouple checking whether the script has a critical bug\n",
"from whether the data or model code is broken\n",
"by testing on some basic \"fake data\",\n",
"based on a utility from `torchvision`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k4NIc3uWvLsA"
},
"outputs": [],
"source": [
"from text_recognizer.data import FakeImageData\n",
"\n",
"\n",
"FakeImageData.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deN0swwlvLsA"
},
"source": [
"We then test on the actual data with a smaller version of the real model.\n",
"\n",
"We use the Lightning `--fast_dev_run` feature,\n",
"which sets the number of training, validation, and test batches to `1`.\n",
"\n",
"We use a smaller version so that this test can run in just a few minutes\n",
"on a CPU without acceleration.\n",
"\n",
"That allows us to run our tests in environments without GPUs,\n",
"which saves on costs for executing tests.\n",
"\n",
"Here's the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z4J0_uD9vLsA"
},
"outputs": [],
"source": [
"!cat training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-7u9zS1vLsA",
"scrolled": false
},
"outputs": [],
"source": [
"! ./training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UTzfo11KClV3"
},
"source": [
"The above tests don't actually check\n",
"whether any learning occurs,\n",
"they just check\n",
"whether training runs mechanically,\n",
"without any errors.\n",
"\n",
"We also need a\n",
"[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n",
"for learning.\n",
"For that we recommend checking whether\n",
"the model can learn the right\n",
"outputs for a single batch --\n",
"to \"memorize\" the outputs for\n",
"a particular input.\n",
"\n",
"This memorization test won't\n",
"catch every bug or issue in training,\n",
"which is notoriously difficult,\n",
"but it will flag\n",
"some of the most serious issues."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DVSp3aAvLsA"
},
"source": [
"The script below runs a memorization test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2DFVVrxpvLsA"
},
"source": [
"It takes up to two arguments:\n",
"a `MAX`imum number of `EPOCHS` to run for and\n",
"a `CRITERION` value of the loss to test against.\n",
"\n",
"The test passes if the loss is lower than the `CRITERION` value\n",
"after the `MAX`imum number of `EPOCHS` has passed."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oEhJH0e5vLsB"
},
"source": [
"The important line in this script is the one that invokes our training script,\n",
"`training/run_experiment.py`.\n",
"\n",
"The arguments to `run_experiment` have been tuned for maximum possible speed:\n",
"turning off regularization, shrinking the model,\n",
"and skipping parts of Lightning that we don't want to test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T-fFs1xEvLsB"
},
"outputs": [],
"source": [
"!cat training/tests/test_memorize_iam.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X-47tUA_YNGe"
},
"source": [
"If you'd like to see what a memorization run looks like,\n",
"flip the `running_memorization` flag to `True`\n",
"and watch the results stream in to W&B.\n",
"\n",
"The cell should run in about ten minutes on a commodity GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GwTEsZwKvLsB"
},
"outputs": [],
"source": [
"%%time\n",
"running_memorization = False\n",
"\n",
"if running_memorization:\n",
" max_epochs = 1000\n",
" loss_criterion = 0.05\n",
" !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Troubleshooting model speed with the PyTorch Profiler"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpbN-Om2Drf-"
},
"source": [
"Testing code is only half the story here:\n",
"we also need to fix the issues that our tests flag.\n",
"This is the process of troubleshooting.\n",
"\n",
"In this lab,\n",
"we'll focus on troubleshooting model performance issues:\n",
"what to do when your model runs too slowly."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NZzwELPXvLsD"
},
"source": [
"Troubleshooting deep neural networks for speed is challenging.\n",
"\n",
"There are at least three different common approaches,\n",
"each with an increasing level of skill required:\n",
"\n",
"1. Follow best practices advice from others\n",
"([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n",
"[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n",
"2. Take code that runs slowly and use empirical observations to iteratively improve it.\n",
"3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n",
"\n",
"For the full stack deep learning engineer,\n",
"the final level is typically out of reach,\n",
"unless you're specializing in the model performance\n",
"part of the stack in particular.\n",
"\n",
"So we recommend reaching the middle level,\n",
"and this segment of the lab walks through the\n",
"tools that make this easier."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3_yp87UrFZ8M"
},
"source": [
"Because neural network training involves GPU acceleration,\n",
"generic Python profiling tools like\n",
"[`py-spy`](https://github.com/benfred/py-spy)\n",
"won't work, and\n",
"we'll need tools specialized for tracing and profiling DNN training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yspsYVFGEyZm"
},
"source": [
"In general, these tools are for observing what happens while your code is executing:\n",
"_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n",
"\n",
"Because they help us observe the execution in detail,\n",
"they will also help us understand just what is going on during\n",
"a PyTorch training step in greater detail."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YqXq2hKuvLsE"
},
"source": [
"To support profiling and tracing,\n",
"we've added a new argument to `training/run_experiment.py`, `--profile`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z_GMMViWvLsE"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZldoksHPvLsE"
},
"source": [
"As with experiment management, this relies mostly on features of PyTorch Lightning,\n",
"which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n",
"and we just add a few lines of customization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F2iJ0_A6vLsE"
},
"outputs": [],
"source": [
"!cat training/run_experiment.py | grep args.profile -A 5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aw3ppgndvLsE"
},
"source": [
"For more on profiling with Lightning, see the\n",
"[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCAmNW3QEtcD"
},
"source": [
"The cell below runs an epoch of training with tracing and profiling turned on\n",
"and then saves the results locally and to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t4o3ylDgr46F",
"scrolled": false
},
"outputs": [],
"source": [
"import glob\n",
"\n",
"import torch\n",
"import wandb\n",
"\n",
"from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n",
"\n",
"\n",
"# make it easier to separate these from training runs\n",
"%env WANDB_JOB_TYPE=profile\n",
"\n",
"batch_size = 16\n",
"num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n",
"gpus = 1 # must be run with accelerator\n",
"\n",
"%run training/run_experiment.py --wandb --profile \\\n",
" --max_epochs=1 \\\n",
" --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n",
" --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n",
" --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus=1\n",
"\n",
"latest_expt = wandb.run\n",
"\n",
"try: # add execution trace to logged and versioned binaries\n",
" folder = wandb.run.dir\n",
" trace_matcher = wandb.run.dir + \"/*.pt.trace.json\"\n",
" trace_file = glob.glob(trace_matcher)[0]\n",
" trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n",
" trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n",
" wandb.log_artifact(trace_at)\n",
"except IndexError:\n",
" print(\"trace not found\")\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ePTkS3EqO5tN"
},
"source": [
"We get out a table of statistics in the terminal,\n",
"courtesy of Lightning.\n",
"\n",
"Each row lists an operation\n",
"and provides information,\n",
"described in the column headers,\n",
"about the time spent on that operation\n",
"across all the training steps we profiled.\n",
"\n",
"With practice, some useful information can be read out from this table,\n",
"but it's better to start from both a less detailed view,\n",
"in the TensorBoard dashboard,\n",
"and a more detailed view,\n",
"using the Chrome Trace viewer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TzV62f3c7-Bi"
},
"source": [
"## High-level statistics from the PyTorch Profiler in TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mNPKXkYw8NWd"
},
"source": [
"Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CbItwuT88eAV"
},
"outputs": [],
"source": [
"your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n",
"\n",
"print(your_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jE_LooMYHFpF"
},
"source": [
"If at any point you run into issues,\n",
"like the description not matching what you observe,\n",
"check out one of our example runs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "za2zybSwIo5C"
},
"outputs": [],
"source": [
"example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n",
"print(example_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xlrhl1n4HYU6"
},
"source": [
"Once the TensorBoard session has loaded up,\n",
"we are dropped into the Overview\n",
"(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n",
"for an example).\n",
"\n",
"In the top center, we see the **GPU Summary** for our system.\n",
"\n",
"In addition to the name of our GPU,\n",
"there are a few configuration details and top-level statistics.\n",
"They are (tersely) documented\n",
"[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MmBhUDgDLhd1"
},
"source": [
"- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n",
"this is effectively a coarse \"version number\" for your GPU hardware.\n",
"It indexes which features are available,\n",
"with more advanced features being available only at higher compute capabilities.\n",
"It does not directly index the speed or memory of the GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "voUgT6zuLyi0"
},
"source": [
"- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yl-IndtXE4b4"
},
"source": [
"- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n",
"for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n",
"Tensor Cores.\n",
"If you're running on an older GPU without Tensor Cores,\n",
"you should consider upgrading.\n",
"If you're running a more recent GPU but not seeing Tensor Core usage,\n",
"you should switch to half-precision (16-bit) floating point numbers,\n",
"e.g. via mixed-precision training,\n",
"which Tensor Cores are specialized for."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XxcUf0bBNXy_"
},
"source": [
"- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n",
"at a lower level than just whether something is running at all,\n",
"as in utilization.\n",
"Unlike utilization, reaching 100% is not generally feasible\n",
"and sometimes not desirable.\n",
"Increasing these numbers requires expertise in\n",
"CUDA programming, so we'll target utilization instead."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A88pQn4YMMKc"
},
"source": [
"- **Execution Summary**: This table and pie chart indicate\n",
"how much time within a profiled step\n",
"was spent in each category.\n",
"The value for \"kernel\" execution here\n",
"is equal to the GPU utilization,\n",
"and we want that number to be as close to 100%\n",
"as possible.\n",
"This summary helps us know which\n",
"other operations are taking time,\n",
"like memory being copied between CPU and GPU (`memcpy`)\n",
"or `DataLoader`s executing on the CPU,\n",
"so we can decide where the bottleneck is."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6qjW1RlTQRPv"
},
"source": [
"At the very bottom, you'll find a\n",
"**Performance Recommendation**\n",
"tab that sometimes suggests specific methods for improving performance.\n",
"\n",
"If this tab makes suggestions, you should certainly take them!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pWY5AhrcRQmJ"
},
"source": [
"For more on using the profiler in TensorBoard,\n",
"including some of the other, more detailed views\n",
"available via the \"Views\" dropdown menu, see\n",
"[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQwrPY_H77H8"
},
"source": [
"## Going deeper with the Chrome Trace Viewer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yhwo7fslvLsH"
},
"source": [
"So far, we've seen summary-level information about our training steps\n",
"in the table from Lightning and in the TensorBoard Overview.\n",
"These give aggregate statistics about the computations that occurred,\n",
"but understanding how to interpret those statistics\n",
"and use them to speed up our networks\n",
"requires understanding just what is\n",
"happening in our training step.\n",
"\n",
"Fundamentally,\n",
"all computations are processes that unfold in time.\n",
"\n",
"If we want to really understand our training step,\n",
"we need to display it that way:\n",
"what operations were occurring,\n",
"on both the CPU and GPU,\n",
"at each moment in time during the training step.\n",
"\n",
"This information on timing is collected in the trace.\n",
"One of the best tools for viewing the trace over time\n",
"is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wUkZItxYc20A"
},
"source": [
"Let's tour the trace we just logged\n",
"with an aim to really understanding just\n",
"what is happening when we call\n",
"`training_step`\n",
"and by extension `.forward`, `.backward`, and `optimizer.step`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9w9F2UA7Qctg"
},
"source": [
"The Chrome Trace Viewer is built into W&B,\n",
"so we can view our traces in their interface.\n",
"\n",
"The cell below embeds the trace inside the notebook,\n",
"but you may wish to open it separately,\n",
"with the \"Open page\" button or by navigating to the URL,\n",
"so that you can interact with it\n",
"as you read the description below.\n",
"Display directly on W&B is also a bit less temperamental\n",
"than display on W&B inside a notebook.\n",
"\n",
"Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n",
"so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n",
"It also can interact poorly with browser extensions (e.g. ad blockers),\n",
"so you may need to deactivate them temporarily in order to see it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OMUs4aby6Rfd"
},
"outputs": [],
"source": [
"trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n",
"trace_url = trace_files_url + \"training_step.pt.trace.json\"\n",
"\n",
"example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n",
"\n",
"print(trace_url)\n",
"IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qNVpGeQtQjMG"
},
"source": [
"> **Heads up!** We're about to do a tour of the\n",
"> precise details of the tracing information logged\n",
"> during the execution of the training code.\n",
"> The only way to learn how to troubleshoot model performance\n",
"> empirically is to look at the details,\n",
"> but the details depend on the precise machine being used\n",
"> -- GPU and CPU and RAM.\n",
"> That means even within Colab,\n",
"> these details change from session to session.\n",
"> So if you don't observe a phenomenon or feature\n",
"> described in the tour below, check out\n",
"> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n",
"> on W&B while reading through the next section of the lab,\n",
"> and return to your trace once you understand the trace viewer better at the end.\n",
"> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n",
"> can sometimes be a bit janky."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kXMcBhnCgdN_"
},
"source": [
"This trace reveals, in nanosecond-level detail,\n",
"what's going on inside of a `training_step`\n",
"on both the GPU and the CPU.\n",
"\n",
"Time is on the horizontal axis.\n",
"Colored bars represent method calls,\n",
"and the methods called by a method are placed underneath it vertically,\n",
"a visualization known as an\n",
"[icicle chart](https://www.brendangregg.com/flamegraphs.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "67BsNzDfVIeg"
},
"source": [
"Let's orient ourselves with some gross features:\n",
"the forwards pass,\n",
"GPU kernel execution,\n",
"the backwards pass,\n",
"and the optimizer step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBEFgtRCKqrh"
},
"source": [
"### The forwards pass"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nYhiWesVMjK"
},
"source": [
"Type `resnet` into the search bar in the top-right.\n",
"\n",
"This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n",
"\n",
"It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n",
"\n",
"You can click the arrows next to that tile to partially collapse these blocks.\n",
"\n",
"Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n",
"It should be at roughly the same height.\n",
"\n",
"Clear the search bar so that the trace is in color.\n",
"Zoom in on the area of the forwards pass\n",
"using the \"zoom\" tool in the floating toolbar,\n",
"so you can see more detail.\n",
"The zoom tool is indicated by a two-headed arrow\n",
"pointing into and out of the screen.\n",
"\n",
"Switch to the \"drag\" tool,\n",
"represented by a four-headed arrow.\n",
"Click-and-hold to use this tool to focus\n",
"on different parts of the timeline\n",
"and click on the individual colored boxes\n",
"to see details about a particular method call.\n",
"\n",
"As we go down in the icicle chart,\n",
"we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n",
"to much more precise `cudnn` and `cuda` operations\n",
"(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n",
"\n",
"`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n",
"is the tensor math library in PyTorch\n",
"that links to specific backends like `cudnn`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fq181ybIvLsH"
},
"source": [
"### GPU kernel execution"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbkWp5aKvLsH"
},
"source": [
"Towards the bottom, you should see a section labeled \"GPU\".\n",
"The label appears on the far left.\n",
"\n",
"Within it, you'll see one or more \"`stream`s\".\n",
"These are units of work on a GPU,\n",
"akin loosely to threads on the CPU.\n",
"\n",
"When there are colored bars in this area,\n",
"the GPU is doing work of some kind.\n",
"The fraction of this bar that is filled in with color\n",
"is the same as the \"GPU Utilization %\" we've seen previously.\n",
"So the first thing to visually assess\n",
"in a trace view of PyTorch code\n",
"is what fraction of this area is filled with color.\n",
"\n",
"In CUDA, work is queued up to be\n",
"placed into streams and completed, on the GPU,\n",
"in a distributed and asynchronous manner.\n",
"\n",
"The selection of which work to do\n",
"is happening on the CPU,\n",
"and that's what we were looking at above.\n",
"\n",
"The CPU and the GPU have to work together to coordinate\n",
"this work.\n",
"\n",
"Type `cuda` into the search bar and you'll see these coordination operations happening:\n",
"`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n",
"\n",
"Running the same PyTorch model\n",
"with the same high level operations like `Conv2d` in different versions of PyTorch,\n",
"on different GPUs, and even on tensors of different sizes will result\n",
"in different choices of concrete kernel operation,\n",
"e.g. different matrix multiplication algorithms.\n",
"\n",
"Type `sync` into the search bar and you'll see places where either work on the GPU\n",
"or work on the CPU needs to await synchronization,\n",
"e.g. copying data from the CPU to the GPU\n",
"or the CPU waiting to decide what to do next\n",
"on the basis of the contents of a tensor.\n",
"\n",
"If you see a \"sync\" block above an area\n",
"where the stream on the GPU is empty,\n",
"you've got a performance bottleneck due to synchronization\n",
"between the CPU and GPU.\n",
"\n",
"To resolve the bottleneck,\n",
"head up the icicle chart until you reach the recognizable\n",
"PyTorch modules and operations.\n",
"Find where they are called in your PyTorch module.\n",
"That's a good place to review your code to understand why the synchronization is happening\n",
"and remove it if it's not necessary."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XeMPbu_jvLsI"
},
"source": [
"### The backwards pass\n",
"\n",
"Type in `backward` into the search bar.\n",
"\n",
"This will highlight components of our backwards pass.\n",
"\n",
"If you read it from left to right,\n",
"you'll see that it begins by calculating the loss\n",
"(search for `NllLoss2DBackward` if you can't find it)\n",
"and ends by doing a `ConvolutionBackward`,\n",
"the first layer of the ResNet.\n",
"It is, indeed, backwards.\n",
"\n",
"Like the forwards pass,\n",
"the backwards pass also involves the CPU\n",
"telling the GPU which kernels to run.\n",
"It's typically run in a separate\n",
"thread from the forwards pass,\n",
"so you'll see it separated out from the forwards pass\n",
"in the trace viewer.\n",
"\n",
"Generally, there's no need to specifically optimize the backwards pass --\n",
"removing bottlenecks in the forwards pass results in a fast backwards pass.\n",
"\n",
"One reason why is that these two passes are just\n",
"\"transposes\" of one another,\n",
"so they share a lot of properties,\n",
"and bottlenecks in one become bottlenecks in the other.\n",
"We can choose to optimize either one of the two.\n",
"But the forwards pass is under our direct control,\n",
"so it's easier for us to reason about.\n",
"\n",
"Another reason is that the forwards pass is more likely to have bottlenecks.\n",
"The forwards pass is a dynamic process,\n",
"with each line of Python adding more to the compute graph.\n",
"Backwards passes, on the other hand, use a static compute graph,\n",
"the one just defined by the forwards pass,\n",
"so more optimizations are possible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gWiDw0vCvLsI"
},
"source": [
"### The optimizer step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ndfkzEdnvLsI"
},
"source": [
"Type `Adam.step` into the search bar to highlight the computations of the optimizer.\n",
"\n",
"As with the two passes,\n",
"we are still using the CPU\n",
"to launch kernels on the GPU.\n",
"But now the CPU is looping,\n",
"in Python, over the parameters\n",
"and applying the Adam update rules to each.\n",
"\n",
"We now know enough to see that\n",
"this is not great for our GPU utilization:\n",
"there are many areas of gray\n",
"in between the colored bars\n",
"in the GPU stream in this area.\n",
"\n",
"In the time it takes CUDA to multiply\n",
"thousands of numbers,\n",
"Python has not yet finished cleaning up\n",
"after its request for that multiplication.\n",
"\n",
"As of writing in August 2022,\n",
"more efficient optimizers are not a stable part of PyTorch (v1.12), but\n",
"[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n",
"and stable implementations outside of PyTorch.\n",
"The standard implementations are\n",
"[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n",
"not to be confused with the\n",
"[Apex Optimizers Project](https://www.apexoptimizers.com/),\n",
"which is a collection of fitness-themed cheetah NFTs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WX0jxeafvLsI"
},
"source": [
"## Take-aways for PyTorch performance bottleneck troubleshooting"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CugD-bK2vLsI"
},
"source": [
"Our goal here was to learn some basic principles and tools for troubleshooting bottlenecks,\n",
"focusing on the most common issues and the lowest-hanging fruit in PyTorch code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SwHwJkVMHYGA"
},
"source": [
"\n",
"Here's an overview in terms of a \"host\",\n",
"generally the CPU,\n",
"and a \"device\", here the GPU.\n",
"\n",
"- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n",
"- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n",
" - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n",
"- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n",
"stores information about the contents of tensors,\n",
"not just their metadata."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gntx28p9cBP5"
},
"source": [
"Towards that goal, we viewed the trace to get an understanding of\n",
"what's going on inside a PyTorch training step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AKvZGPnkeXvq"
},
"source": [
"Here's what that means in terms of troubleshooting bottlenecks.\n",
"\n",
"We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n",
"before the previous kernel finishes.\n",
"\n",
"Ideally, the CPU is actually getting far _ahead_ of execution\n",
"on the GPU.\n",
"If the CPU makes it all the way through the backwards pass before the GPU is done,\n",
"that's great!\n",
"The GPU(s) are the expensive part,\n",
"and it's easy to use multiprocessing so that\n",
"the CPU has other things to do.\n",
"\n",
"This helps explain at least one common piece of advice:\n",
"the larger our batches are,\n",
"the more work the GPU has to do for the same work done by the CPU,\n",
"and so the better our utilization will be."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XMztpa-TccH4"
},
"source": [
"We operationalize our desire to never be waiting on the CPU with a simple metric:\n",
"**100% GPU utilization**, meaning a kernel is running at all times.\n",
"\n",
"This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n",
"\n",
"You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will cost more than twice as much."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7kYBygfScR6z"
},
"source": [
"Here are some of the most common issues that lead to low GPU Utilization, and how to resolve them:\n",
"1. **The CPU is too weak**.\n",
"Because so much of the discussion around DNN performance is about GPUs,\n",
"it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n",
"_Resolution_:\n",
"Use nice CPUs, like\n",
"[Threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n",
"2. **Too much Python during the `training_step`**.\n",
"Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on a `__init__`\n",
"that takes longer than running an entire layer.\n",
"_Resolution_:\n",
"Look for low utilization areas of the trace\n",
"and check what's happening on the CPU at that time\n",
"and carefully review the Python code being executed.\n",
"3. **Unnecessary Host/Device synchronization**.\n",
"If one of your operations depends on the values in a tensor,\n",
"like `if xs.mean() >= 0`,\n",
"you'll induce a synchronization between\n",
"the host and the device and possibly lead\n",
"to an expensive and slow copy of data.\n",
"_Resolution_:\n",
"Replace these operations as much as possible\n",
"with purely array-based calculations.\n",
"4. **Bottlenecking on the DataLoader**.\n",
"In addition to coordinating the work on the GPU,\n",
"CPUs often perform heavy data operations,\n",
"including communication over the network\n",
"and writing to/reading from disk.\n",
"These are generally done in parallel to the forwards\n",
"and backwards passes,\n",
"but if they don't finish before the next batch is needed,\n",
"they will become the bottleneck.\n",
"_Resolution_:\n",
"Get better hardware for compute,\n",
"memory, and network.\n",
"For software solutions, the answer \n",
"is a bit more complex and application-dependent.\n",
"For generic tips, see\n",
"[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n",
"in the PyTorch forums.\n",
"For techniques in computer vision, see\n",
"[the FFCV library](https://github.com/libffcv/ffcv)\n",
"and for techniques in NLP, see e.g.\n",
"[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n",
"and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i2WYS8bQvLsJ"
},
"source": [
"### Further steps in making DNNs go brrrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T0wW2_lRKfY1"
},
"source": [
"It's important to note that utilization\n",
"is just an easily measured metric\n",
"that can reveal common bottlenecks.\n",
"Having high utilization does not automatically mean\n",
"that your performance is fully optimized.\n",
"\n",
"For example,\n",
"synchronization events between GPUs\n",
"are counted as kernels,\n",
"so a deadlock during distributed training\n",
"can show up as 100% utilization,\n",
"despite literally no useful work occurring.\n",
"\n",
"Just switching to \n",
"double precision floats, `--precision=64`,\n",
"will generally lead to much higher utilization.\n",
"The GPU operations take longer\n",
"for roughly the same amount of CPU effort,\n",
"but the added precision brings no benefit.\n",
"\n",
"In particular, it doesn't make for models\n",
"that perform better on our correctness metrics,\n",
"like loss and accuracy.\n",
"\n",
"Another useful yardstick to add\n",
"to utilization is examples per second,\n",
"which incorporates how quickly the model is processing data examples\n",
"and calculating gradients.\n",
"\n",
"But really,\n",
"the gold star is _decrease in loss per second_.\n",
"This metric connects model design choices\n",
"and hyperparameters with purely engineering concerns,\n",
"so it disrespects abstraction barriers\n",
"and doesn't generally lead to actionable recommendations,\n",
"but it is, in the end, the real goal:\n",
"make the loss go down faster so we get better models sooner."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EFzPsplfdo_o"
},
"source": [
"For an abstract overview of PyTorch internals,\n",
"see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"For more on performance considerations in PyTorch,\n",
"see [Horace He's blog post](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yq6-S6TC38AY"
},
"source": [
"### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n",
"\n",
"One of the most important features for making\n",
"PyTorch run quickly is the\n",
"`MultiprocessingDataLoader`,\n",
"which executes batching of data in a separate process\n",
"from the forwards and backwards passes.\n",
"\n",
"By default in PyTorch,\n",
"this feature is actually turned off,\n",
"via the `DataLoader` argument `num_workers`\n",
"having a default value of `0`,\n",
"but we set the `DEFAULT_NUM_WORKERS`\n",
"to a value based on the number of CPUs\n",
"available on the system running the code.\n",
"\n",
"Re-run the profiling cell,\n",
"but set `num_workers` to `0`\n",
"to turn off multiprocessing.\n",
"\n",
"Compare and contrast the two traces,\n",
"both for total runtime\n",
"(see the time axis at the top of the trace)\n",
"and for utilization.\n",
"\n",
"If you're unable to run the profiles,\n",
"see the results\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n",
"which juxtaposes two traces,\n",
"with in-process dataloading on the left and\n",
"multiprocessing dataloading on the right."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T2i_a5eVeIoA"
},
"source": [
"The file below incorrectly implements and then incorrectly tests\n",
"a simple PyTorch utility for adding five to every entry of a tensor\n",
"and then calculating the sum.\n",
"\n",
"Even worse, it does it with horrible style!\n",
"\n",
"The cells below apply our linting checks\n",
"(after automatically fixing the formatting)\n",
"and run the test.\n",
"\n",
"Fix all of the lints,\n",
"implement the function correctly,\n",
"and then implement some basic tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wSon2fB5VVM_"
},
"source": [
"- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aYiRvU4HA84t"
},
"outputs": [],
"source": [
"%%writefile training/fixme.py\n",
"import torch\n",
"from training import run_experiment\n",
"from numpy import *\n",
"import random\n",
"from pathlib import Path\n",
"\n",
"\n",
"\n",
"\n",
"def add_five_and_sum(tensor):\n",
" # this function is not implemented right,\n",
" # but it's supposed to add five to all tensor entries and sum them up\n",
" return 1\n",
"\n",
"def test_add_five_and_sum():\n",
" # and this test isn't right either! plus this isn't exactly a docstring\n",
" all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n",
" all_fives = 5 * all_ones\n",
" assert False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EXJpmvuzT1w0"
},
"outputs": [],
"source": [
"!pre-commit run black --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SRO-oJfdUrcQ"
},
"outputs": [],
"source": [
"!cat training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jM8NHxVbSEQD"
},
"outputs": [],
"source": [
"!pre-commit run --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kj0VMBSndtkc"
},
"outputs": [],
"source": [
"!pytest training/fixme.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab05_troubleshooting.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab06/notebooks/lab06_data.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 06: Data Annotation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How the `IAM` handwriting dataset is structured on disk and how it is processed into an ML-friendly format\n",
"- How to setup a [Label Studio](https://labelstud.io/) data annotation server\n",
"- Just how messy data really is"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 6\n",
"\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpvaHz9TEGwV"
},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gsXpeXi2EGwV"
},
"outputs": [],
"source": [
"from IPython.display import IFrame\n",
"\n",
"\n",
"IFrame(src=\"https://fsdl.me/2022-lab-06-video-embed\", width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XTkKzEMNR8XZ"
},
"source": [
"# `IAMParagraphs`: From annotated data to a PyTorch `Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3mQLbjuiwZuj"
},
"source": [
"We've used the `text_recognizer.data` submodule\n",
"and its `LightningDataModule`s -- `IAMLines` and `IAMParagraphs` --\n",
"for lines and paragraphs of handwritten text\n",
"from the\n",
"[IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database).\n",
"\n",
"These classes convert data from a database-friendly format\n",
"designed for storage and transfer into the\n",
"format our DNNs expect:\n",
"PyTorch `Tensor`s.\n",
"\n",
"In this section,\n",
"we'll walk through that process in detail.\n",
"\n",
"In the following section,\n",
"we'll see how data\n",
"goes from signals measured in the world\n",
"to the format we consume here."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "499c23a6"
},
"source": [
"## Dataset structure on disk"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a3438d2e"
},
"source": [
"We begin by downloading the raw data to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "18900eec"
},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM\n",
"\n",
"iam = IAM()\n",
"iam.prepare_data()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a332f359"
},
"source": [
"The `IAM` dataset is downloaded as a zip file\n",
"and then unzipped:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d6c44266"
},
"outputs": [],
"source": [
"from text_recognizer.metadata.iam import DL_DATA_DIRNAME\n",
"\n",
"\n",
"iam_dir = DL_DATA_DIRNAME\n",
"!ls {iam_dir}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8463c2d1"
},
"source": [
"The unzipped dataset is not simply a flat directory of files.\n",
"\n",
"Instead, there are a number of subfolders,\n",
"each of which contains a particular type of data or metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "536924f7"
},
"outputs": [],
"source": [
"iamdb = iam_dir / \"iamdb\"\n",
"\n",
"!du -h {iamdb}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b745a594"
},
"source": [
"For example, the `task` folder contains metadata about canonical dataset splits:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "84c21f75"
},
"outputs": [],
"source": [
"!find {iamdb / \"task\"} | grep \"\\\\.txt$\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mEb0Pdm4vIHe"
},
"source": [
"We find the images of handwritten text in the `forms` folder.\n",
"\n",
"An individual \"datapoint\" in `IAM` is a \"form\",\n",
"because the humans whose hands wrote the text were prompted to write on \"forms\",\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "945d5e3a"
},
"outputs": [],
"source": [
"from IPython.display import Image\n",
"\n",
"\n",
"form_fn, = !find {iamdb}/forms | grep \".jpg$\" | sort | head -n 1\n",
"\n",
"print(form_fn)\n",
"Image(filename=form_fn, width=\"360\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b9e9e384"
},
"source": [
"Meanwhile, the `xml` files contain the data annotations,\n",
"written out as structured text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6add5c5a"
},
"outputs": [],
"source": [
"xml_fn, = !find {iamdb}/xml | grep \"\\.xml$\" | sort | head -n 1\n",
"\n",
"!cat {xml_fn} | grep -A 100 \"handwritten-part\" | grep \"
", "", " and ", *tokens, " and ", *tokens, ""]
self.end_index = self.inverse_mapping["",
""]
self.end_token = inverse_mapping[""]
self.end_token = inverse_mapping[""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing,\n",
"matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanging of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect to see that something around one in ten predictions are correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how good our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run our `backward`s pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randint(0, len(x_train))\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset was, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need's not out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
" self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
" self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameters`\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built in to `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameters`,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch assigns the `nn.Parameters`\n",
"of any child `nn.Module`s to the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (`fc` and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine to move electrons around."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, the problem of CPUs,\n",
"which are general purpose computing devices,\n",
"being too slow is solved by using more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"the redheaded stepchild of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
" url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
" \n",
" def __init__(self, dir, bs=32):\n",
" self.dir = dir\n",
" self.bs = bs\n",
" self.path = self.dir / self.filename\n",
"\n",
" def prepare_data(self):\n",
" if not (self.path).exists():\n",
" content = requests.get(self.url + self.filename).content\n",
" self.path.open(\"wb\").write(content)\n",
"\n",
" def setup(self):\n",
" with gzip.open(self.path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
" x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
" )\n",
" \n",
" self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
" self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
" \n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've defined the `MNISTDataModule` class and attached the new `fit` to the `MLP`,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab07/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch for iteratively generating those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"That's in contrast to the other common way frameworks are designed, which is\n",
"to provide abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of over-riding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging # import some stdlib components to control what's display\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"try: # try using the LinearRegression LightningModule defined above\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n",
"\n",
" model = LinearRegression()\n",
"\n",
" # we'll explain how the Trainer works in a bit\n",
" trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
" trainer.fit(model=model) \n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
" print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n",
"\n",
"finally: # bring back info-level logging\n",
" logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be over-ridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
" super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
" # set some constants, like the train/val split\n",
" self.size = size\n",
" self.batch_size = batch_size\n",
" self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
" n_train = math.floor(self.size * train_frac)\n",
" self.train_indices = list(range(n_train))\n",
" self.val_indices = list(range(n_train, self.size)) # starts right after the last train index, so the splits don't overlap\n",
"\n",
" # under the hood, we've still got a torch Dataset\n",
" self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once per machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n",
" if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n",
" self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
" self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self): # prepares state that needs to be set once per node\n",
" pass # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and over-riding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"obviating the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Also, implementing metric calculation during training adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab07/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True to re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other points.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation invariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution is said to have a _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we repeatedly apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper into convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (to right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represents (layer IDs at bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that custom here and used `__repr__`. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with intermittent \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work off of\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting the wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like DropOut\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first things you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"This helps prevent training runs from crashing\n",
"at the end of the first epoch,\n",
"which is otherwise the first time validation loops are triggered\n",
"and is sometimes hours into training,\n",
"by crashing them quickly at the start.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `<P>`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 03: Transformers and Paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The fundamental reasons why the Transformer is such\n",
"a powerful and popular architecture\n",
"- Core intuitions for the behavior of Transformer architectures\n",
"- How to use a convolutional encoder and a Transformer decoder to recognize\n",
"entire paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 3\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Transformers?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal in building a text recognizer is to take a two-dimensional image\n",
"and convert it into a one-dimensional sequence of characters\n",
"from some alphabet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional neural networks,\n",
"discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n",
"are great at encoding images,\n",
"taking them from their raw pixel values\n",
"to a more semantically meaningful numerical representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But how do we go from that to a sequence of letters?\n",
"And what's especially tricky:\n",
"the number of letters in an image is separable from its size.\n",
"A screenshot of this document has a much higher density of letters\n",
"than a close-up photograph of a piece of paper.\n",
"How do we get a _variable-length_ sequence of letters,\n",
"where the length need have nothing to do with the size of the input tensor?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n",
"they were\n",
"[originally introduced](https://arxiv.org/abs/1706.03762)\n",
"for transforming one sequence into another,\n",
"as in machine translation.\n",
"This makes them a natural fit for processing language.\n",
"\n",
"But they have also found success in other domains --\n",
"at the time of this writing, large transformers\n",
"dominate the\n",
"[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n",
"that has become a de facto standard for comparing models\n",
"and are finding\n",
"[application in reinforcement learning](https://arxiv.org/abs/2106.01345)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we will use a Transformer as a key component of our final architecture:\n",
"we will encode our input images with a CNN\n",
"and then read them out into a text sequence with a Transformer.\n",
"\n",
"Before trying out this new model,\n",
"let's first get an understanding of why the Transformer architecture\n",
"has become so popular by walking through its history\n",
"and then get some intuition for how it works\n",
"by looking at some\n",
"[recent work](https://transformer-circuits.pub/)\n",
"on explaining the behavior of both toy models and state-of-the-art language models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kmKqjbvd-Mj3"
},
"source": [
"## Why not convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SRqkUMdM-OxU"
},
"source": [
"In the ancient beforetimes (i.e. 2016),\n",
"the best models for natural language processing were all\n",
"_recurrent_ neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional networks were also occasionally used,\n",
"but they suffered from a serious issue:\n",
"their architectural biases don't fit text.\n",
"\n",
"First, _translation equivariance_ no longer holds.\n",
"The beginning of a piece of text is often quite different from the middle,\n",
"so the absolute position matters.\n",
"\n",
"Second, _locality_ is not as important in language.\n",
"The name of a character that hasn't appeared in thousands of pages\n",
"can become salient when someone asks, \"Whatever happened to\n",
"[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n",
"\n",
"Consider interpreting a piece of text like the Python code below:\n",
"```python\n",
"def do(arg1, arg2, arg3):\n",
" a = arg1 + arg2\n",
" b = arg3[:3]\n",
" c = a * b\n",
" return c\n",
"\n",
"print(do(1, 1, \"ayy lmao\"))\n",
"```\n",
"\n",
"After a `(` we expect a `)`,\n",
"but possibly very long afterwards,\n",
"[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n",
"and similarly we expect a `]` at some point after a `[`.\n",
"\n",
"For translation variance, consider\n",
"that we interpret `*` not by\n",
"comparing it to its neighbors\n",
"but by looking at `a` and `b`.\n",
"We mix knowledge learned through experience\n",
"with new facts learned while reading --\n",
"also known as _in-context learning_.\n",
"\n",
"In a longer text,\n",
"[e.g. the one you are reading now](./lab03_transformers.ipynb),\n",
"the translation variance of text is clearer.\n",
"Every lab notebook begins with the same header,\n",
"setting up the environment,\n",
"but that header never appears elsewhere in the notebook.\n",
"Later positions need to be processed in terms of the previous entries.\n",
"\n",
"Unlike an image, we cannot simply rotate or translate our \"camera\"\n",
"and get a new valid text.\n",
"[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n",
"that can be read without regard to position."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The field of formal language theory,\n",
"which has deep mutual influence with computer science,\n",
"gives one way of explaining the issues with convolutional networks:\n",
"they can only understand languages with _finite contexts_,\n",
"where all the information can be found within a finite window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The immediate solution, drawing from the connections to computer science, is\n",
"[recursion](https://www.google.com/search?q=recursion).\n",
"A network whose output on the final entry of the sequence is a recursive function\n",
"of all the previous entries can build up knowledge\n",
"as it reads the sequence and treat early entries quite differently than it does late ones."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa6cbTlImkEh"
},
"source": [
"In pseudo-code, such a _recurrent neural network_ module might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lKtBoPnglPrW"
},
"source": [
"```python\n",
"def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n",
" next_inputs = input_module(xs[-1])\n",
" next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n",
" return output_module(next_inputs, next_hiddens)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbJPSMnEm516"
},
"source": [
"If you've had formal computer science training,\n",
"then you may be familiar with the power of recursion,\n",
"e.g. the\n",
"[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n",
"that gave its name to the now much better-known\n",
"[startup incubator](https://www.ycombinator.com/).\n",
"\n",
"The particular form of recursion used by\n",
"recurrent neural networks implements a\n",
"[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n",
"\n",
"> If you know a lot of computer science,\n",
"you might be concerned by this connection.\n",
"What about other\n",
"[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n",
"Where are the neural network architectures for differentiable\n",
"[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n",
"Check out Graph Neural Networks,\n",
"[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "63mMTbEBpVuE"
},
"source": [
"Recurrent networks are able to achieve\n",
"[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n",
"\n",
"There are many popular recurrent architectures,\n",
"from the beefy and classic\n",
"[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n",
"to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n",
"([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n",
"all of which have roughly similar capabilities but\n",
"[some of which are easier to train](https://arxiv.org/abs/1611.09913)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwQHVTIslOku"
},
"source": [
"In the same sense that MLPs can model \"any\" feedforward function,\n",
"in principle even basic RNNs\n",
"[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n",
"\n",
"In particular they can model any\n",
"[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n",
"which is a formal way of saying that they can in principle\n",
"do anything a computer is capable of doing.\n",
"\n",
"The question is then..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3J8EoGN3pu7P"
},
"source": [
"## Why aren't we all using RNNs?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDwNWaevpt_3"
},
"source": [
"The guarantees that MLPs can model any function\n",
"or that RNNs can model Turing machines\n",
"provide decent intuition but are not directly practically useful.\n",
"Among other reasons, they don't guarantee learnability --\n",
"that starting from random parameters we can find the parameters\n",
"that implement a given function.\n",
"The\n",
"[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n",
"than would seem from basic theoretical and empirical analysis.\n",
"\n",
"One way of understanding capacity to model language is\n",
"[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n",
"In this model of formal languages,\n",
"Turing machines sit at the top\n",
"([practically speaking](https://arxiv.org/abs/math/0209332)).\n",
"\n",
"With better mathematical models,\n",
"RNNs and LSTMs can be shown to be\n",
"[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n",
"with RNNs looking more like\n",
"[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n",
"and LSTMs coming in\n",
"[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n",
"\n",
"More controversially:\n",
"the Chomsky hierarchy is great for understanding syntax and grammar,\n",
"which makes it great for building parsers\n",
"and working with formal languages,\n",
"but the goal in _natural_ language processing is to understand _natural_ language.\n",
"Most humans' natural language is far from strictly grammatical,\n",
"but that doesn't mean it is nonsense.\n",
"\n",
"And to really \"understand\" language means\n",
"to understand its semantic content, which is fuzzy.\n",
"The most important thing for handling the fuzzy semantic content\n",
"of language is not whether you can recall\n",
"[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n",
"but whether you can model probabilistic relationships between concepts\n",
"in addition to grammar and syntax."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These both leave theoretical room for improvement over current recurrent\n",
"language and sequence models.\n",
"\n",
"But the real cause of the rise of Transformers is that..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dsu1ebvAp-3Z"
},
"source": [
"## Transformers are designed to train fast at scale on contemporary hardware."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4abU5adsPGs"
},
"source": [
"The Transformer architecture has several important features,\n",
"discussed below,\n",
"but one of the most important reasons why it is successful\n",
"is because it can be more easily trained at scale.\n",
"\n",
"This scalability is the focus of the discussion in the paper\n",
"that introduced the architecture,\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"and\n",
"[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n",
"\n",
"The recursion in RNNs is inherently sequential:\n",
"the dependence on the outputs from earlier in the sequence\n",
"means computations within an example cannot be parallelized.\n",
"\n",
"So RNNs must batch across examples to scale,\n",
"but as sequence length grows this hits memory bandwidth limits.\n",
"Serving up large batches quickly with good randomness guarantees\n",
"is also hard to optimize,\n",
"especially in distributed settings.\n",
"\n",
"The Transformer architecture,\n",
"on the other hand,\n",
"can be readily parallelized within a single example sequence,\n",
"in addition to parallelization across batches.\n",
"This can lead to massive performance gains for a fixed scale,\n",
"which means larger, higher capacity models\n",
"can be trained on larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Mzk2haFC_G1"
},
"source": [
"How does the architecture achieve this parallelizability?\n",
"\n",
"Let's start with the architecture diagram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u59eu4snLQfp"
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n",
"\n",
"display.Image(url=base_url + \"/aiayn-figure-1.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ez-XEQ7M0UlR"
},
"source": [
"> To head off a bit of confusion\n",
" in case you've worked with Transformer architectures before:\n",
" the original \"Transformer\" is an encoder/decoder architecture.\n",
" Many LLMs, like GPT models, are decoder only,\n",
" because this has turned out to scale well,\n",
" and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n",
" it's all text anyways.\n",
" We, however, will be using them across modalities,\n",
" so we need an explicit encoder,\n",
" as above. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ok4ksBi4vp89"
},
"source": [
"First focusing on the encoder (left):\n",
"the encoding at a given position is a function of all of the inputs.\n",
"But it is not a function of the previous _encodings_:\n",
"we produce the encodings \"all at once\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN7C-_OqzHP"
},
"source": [
"The decoder (right) does use previous \"outputs\" as its inputs,\n",
"but those outputs are not the vectors of layer activations\n",
"(aka embeddings)\n",
"that are produced by the network.\n",
"They are instead the processed outputs,\n",
"after a `softmax` and an `argmax`.\n",
"\n",
"We could obtain these outputs by processing the embeddings,\n",
"much like in a recurrent architecture.\n",
"In fact, that is one way that Transformers are run.\n",
"It's what happens in the `.forward` method\n",
"of the model we'll be training for character recognition:\n",
"`ResnetTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L5_2WMmtDnJn"
},
"source": [
"Let's look at that forward method\n",
"and connect it to the diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FR5pk4kEyCGg"
},
"outputs": [],
"source": [
"from text_recognizer.models import ResnetTransformer\n",
"\n",
"\n",
"ResnetTransformer.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-J5UFDoPzPbq"
},
"source": [
"`.encode` happens first -- that's the left side of the diagram.\n",
"\n",
"The encoder can in principle be anything\n",
"that produces a sequence of fixed-length vectors,\n",
"but here it's\n",
"[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n",
"\n",
"Then we start iterating over the sequence\n",
"in the `for` loop.\n",
"\n",
"Focus on the first few lines of code.\n",
"We apply `.decode` (right side of the diagram)\n",
"to the outputs so far.\n",
"\n",
"Once we have a new `output`, we apply `.argmax`\n",
"to turn the logits into a concrete prediction of\n",
"a particular token.\n",
"\n",
"This is added as the last output token\n",
"and then the loop happens again."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTcy8-rV1dHr"
},
"source": [
"Run this way, our model looks very much like a recurrent architecture:\n",
"we call the model on its own outputs\n",
"to generate the next value.\n",
"These types of models are also referred to as\n",
"[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n",
"because we predict (as we do in _regression_)\n",
"the next value based on our own (_auto_) output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But Transformers are designed to be _trained_ more scalably than RNNs,\n",
"not necessarily to _run inference_ more scalably,\n",
"and it's actually not the case that our model's `.forward` is called during training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCxMSAWmEKBt"
},
"source": [
"Let's look at what happens during training\n",
"by checking the `training_step`\n",
"of the `LightningModule`\n",
"we use to train our Transformer models,\n",
"the `TransformerLitModel`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o7q8N7P2w4H"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import TransformerLitModel\n",
"\n",
"TransformerLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VgNNOjvzC4y"
},
"source": [
"Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tz-6NGPR4dUr"
},
"source": [
"Let's look at `.teacher_forward`,\n",
"and in particular its type signature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ILc2oWET4i2Z"
},
"outputs": [],
"source": [
"TransformerLitModel.teacher_forward??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lf32lpgrDb__"
},
"source": [
"This is known as \"teacher forcing\".\n",
"The \"teacher\" signal is \"forcing\"\n",
"the model to behave as though\n",
"it got the answer right.\n",
"\n",
"[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n",
"It's more effective here\n",
"because the right teaching signal\n",
"for our network is the target data,\n",
"which we have access to during training,\n",
"whereas in an RNN the best teaching signal\n",
"would be the target embedding vector,\n",
"which we do not know.\n",
"\n",
"During inference, when we don't have access to the ground truth,\n",
"we revert to the autoregressive `.forward` method."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This \"trick\" allows Transformer architectures to readily scale\n",
"up models to the parameter counts\n",
"[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAjqpJm9uUuU"
},
"source": [
"## Is there more to Transformers than just a training trick?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kWCYXeHv7Qc9"
},
"source": [
"[Very](https://arxiv.org/abs/2005.14165),\n",
"[very](https://arxiv.org/abs/1909.08053),\n",
"[very](https://arxiv.org/abs/2205.01068)\n",
"large Transformer models have powered the most recent wave of exciting results in ML, like\n",
"[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n",
"\n",
"They are also the first machine learning models to have come anywhere close to\n",
"deserving the term _artificial intelligence_ --\n",
"a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is surprising because the models and their training procedure are\n",
"(relatively speaking)\n",
"pretty _simple_,\n",
"even if it doesn't feel that way on first pass."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The basic Transformer architecture is just a bunch of\n",
"dense matrix multiplications and non-linearities --\n",
"it's perhaps simpler than a convolutional architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And advances since the introduction of Transformers in 2017\n",
"have not in the main been made by\n",
"creating more sophisticated model architectures\n",
"but by increasing the scale of the base architecture,\n",
"or if anything making it simpler, as in\n",
"[GPT-type models](https://arxiv.org/abs/2005.14165),\n",
"which drop the encoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V1HQS9ey8GMc"
},
"source": [
"These models are also trained on very simple tasks:\n",
"most LLMs are just trying to predict the next element in the sequence,\n",
"given the previous elements --\n",
"a task simple enough that Claude Shannon,\n",
"father of information theory, was\n",
"[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n",
"\n",
"These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n",
"e.g. by scraping the web."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They are also trained in a simple fashion:\n",
"first-order stochastic optimizers, like SGD or an\n",
"[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n",
"intended for the most basic of optimization problems,\n",
"that scale more readily than the second-order optimizers\n",
"that dominate other areas of optimization."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kz9HPDoy7OAl"
},
"source": [
"This is\n",
"[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n",
"of work in ML:\n",
"simple, even seemingly wasteful,\n",
"architectures that scale well and are robust\n",
"to implementation details\n",
"eventually outstrip more clever but\n",
"also more finicky approaches that are harder to scale.\n",
"This lesson has led some to declare that\n",
"[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n",
"in machine learning, and perhaps even in artificial intelligence."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdN9o2Y771YZ"
},
"source": [
"> That is not to say that because the algorithms are relatively simple,\n",
" training a model at this scale is _easy_ --\n",
" [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n",
" [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n",
" [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n",
" But choosing the simplest algorithm at every step makes solving the scaling problem feasible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baVGf6gKFOvs"
},
"source": [
"The importance of scale is the key lesson from the Transformer architecture,\n",
"far more than any theoretical considerations\n",
"or any of the implementation details.\n",
"\n",
"That said, these large Transformer models are capable of\n",
"impressive behaviors and understanding how they achieve them\n",
"is of intellectual interest.\n",
"Furthermore, like any architecture,\n",
"there are common failure modes,\n",
"of the model and of the modelers who use them,\n",
"that need to be taken into account."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"Below, we'll cover two key intuitions about Transformers:\n",
"Transformers are _residual_, like ResNets,\n",
"and they compose _low rank_ sequence transformations.\n",
"Together, this means they act somewhat like a computer,\n",
"reading from and writing to a \"tape\" or memory\n",
"with a sequence of simple instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67R"
},
"source": [
"We'll also cover a surprising implementation detail:\n",
"despite being commonly used for sequence modeling,\n",
"by default the architecture is _position insensitive_."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uni0VTCr9lev"
},
"source": [
"### Intuition #1: Transformers are highly residual."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MoBt-JLJz-d"
},
"source": [
"> The discussion of these intuitions summarizes the discussion in\n",
"[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n",
"from\n",
"[Anthropic](https://www.anthropic.com/),\n",
"an AI safety and research company.\n",
"The figures below are from that blog post.\n",
"It is the spiritual successor to the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"covered in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab).\n",
"If you want to truly understand Transformers,\n",
"we highly recommend you check it out,\n",
"including the\n",
"[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UUbNVvM5Ferm"
},
"source": [
"It's easy to see that ResNets are residual --\n",
"it's in the name, after all.\n",
"\n",
"But Transformers are,\n",
"in some sense,\n",
"even more closely tied to residual computation\n",
"than are ResNets:\n",
"ResNets and related architectures include downsampling,\n",
"so there is not a direct path from inputs to outputs.\n",
"\n",
"In Transformers, the exact same shape is maintained\n",
"from the moment tokens are embedded,\n",
"through dozens or hundreds of intermediate layers,\n",
"and until they are \"unembedded\" into class logits.\n",
"The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n",
"\n",
"The residual stream is easy to see with a change of perspective.\n",
"Instead of the usual architecture diagram above,\n",
"which emphasizes the layers acting on the tensors,\n",
"consider this alternative view,\n",
"which emphasizes the tensors as they pass through the layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRMlVguKKW6y"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-residual-view.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9K3N7ilVkB3"
},
"source": [
"For definitions of variables and terms, see the\n",
"[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arvciE-kKd_L"
},
"source": [
"Note that this is a _decoder-only_ Transformer architecture --\n",
"so it should be compared with the right-hand side of the original architecture diagram above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvrRMd_RKp_G"
},
"source": [
"Notice that outputs of the attention blocks \n",
"and of the MLP layers are\n",
"added to their inputs, as in a ResNet.\n",
"These operations are represented as \"Add & Norm\" layers in the classical diagram;\n",
"normalization is ignored here for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8n_iT-FFAbK"
},
"source": [
"This total commitment to residual operations\n",
"means the size of the embeddings\n",
"(referred to as the \"model dimension\" or the \"embedding dimension\",\n",
"here and below `d_model`)\n",
"stays the same throughout the entire network.\n",
"\n",
"That means, for example,\n",
"that the output of each layer can be used as input to the \"unembedding\" layer\n",
"that produces logits.\n",
"We can read out the computations of intermediate layers\n",
"just by passing them through the unembedding layer\n",
"and examining the logit tensor.\n",
"See\n",
"[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n",
"for detailed experiments and interactive notebooks.\n",
"\n",
"In short, we observe a sort of \"progressive refinement\"\n",
"of the next-token prediction\n",
"as the embeddings proceed, depthwise, through the network."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ovh_3YgY9z2h"
},
"source": [
"### Intuition #2: Transformer heads learn low rank transformations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpNmozlnOdPC"
},
"source": [
"In the original paper and in\n",
"most presentations of Transformers,\n",
"the attention layer is written like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PA7me8gNP5LE"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pseudo-typed PyTorch (based loosely on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n",
"that looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oeict_6wGJgD"
},
"source": [
"```python\n",
"def classic_attention(\n",
" Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" K: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" return torch.softmax(Q @ K.T) @ V\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8pewU90DSuOR"
},
"source": [
"This is effectively exactly\n",
"how it is written\n",
"in PyTorch,\n",
"apart from implementation details\n",
"(look for `bmm` for the matrix multiplications and a `softmax` call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrgTpKFvOhwc"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"F._scaled_dot_product_attention??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebDXZ0tlSe7g"
},
"source": [
"But the best way to write an operation so that a computer can execute it quickly\n",
"is not necessarily the best way to write it so that a human can understand it --\n",
"otherwise we'd all be coding in assembly.\n",
"\n",
"And this is a strange way to write it --\n",
"you'll notice that what we normally think of\n",
"as the \"inputs\" to the layer are not shown.\n",
"\n",
"We can instead write out the attention layer\n",
"as a function of the inputs $x$.\n",
"We write it for a single \"attention head\".\n",
"Each attention layer includes a number of heads\n",
"that read and write from the residual stream\n",
"simultaneously and independently.\n",
"We also add the output layer weights $W_O$\n",
"and we get:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LuFNR67tQpsf"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(\\underbrace{x W_Q^T}_Q \\underbrace{W_K x^T}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVnBjjfOLwxP"
},
"source": [
"or, in pseudo-typed PyTorch:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmpOm-HfGaNz"
},
"source": [
"```python\n",
"def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n",
" key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n",
" key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n",
" # maps queries of residual stream to keys from residual stream, independent of position\n",
"\n",
" value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n",
" output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n",
" value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n",
" # transformation applied to each token, regardless of position\n",
"\n",
" attention_logits = x @ key_query_circuit @ x.T\n",
" attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits)\n",
" # maps positions to positions, often very sparse\n",
"\n",
" value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n",
"\n",
" return attention_map @ value_output # transformed tokens filtered by attention map\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dC0eqxZ6UAGT"
},
"source": [
"Consider the `key_query_circuit`\n",
"and `value_output_circuit`\n",
"matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n",
"\n",
"The key/query dimension, `d_head`,\n",
"is small relative to the model's dimension, `d_model`,\n",
"so $W_{QK}$ and $W_{OV}$ are very low rank,\n",
"[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n",
"that they factorize into two matrices,\n",
"one with a smaller number of rows\n",
"and another with a smaller number of columns.\n",
"That number is called the _rank_.\n",
"\n",
"When computing, these matrices are better represented via their components,\n",
"rather than computed directly,\n",
"which leads to the normal implementation of attention.\n",
"\n",
"In a large language model,\n",
"the ratio of residual stream dimension, `d_model`, to\n",
"the dimension of a single head, `d_head`, is huge, often 100:1.\n",
"That means each query, key, and value computed at a position\n",
"is a fairly simple, low-dimensional feature of the residual stream at that position.\n",
"\n",
"For visual intuition,\n",
"we compare what a matrix whose rank is 1/100th of full rank looks like,\n",
"relative to a full rank matrix of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LUbojJMiW2C"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"\n",
"low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n",
"full_rank = torch.randn(100, 100)\n",
"plt.figure(); plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n",
"plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqBst92-OVka"
},
"source": [
"The pattern in the first matrix is very simple,\n",
"relative to the pattern in the second matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkCGrs9EiVh4"
},
"source": [
"Another feature of low rank transformations is\n",
"that they have a large nullspace or kernel --\n",
"these are directions we can move the input without changing the output.\n",
"\n",
"That means that many changes to the residual stream won't affect the behavior of this head at all."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVz2dQgzhD4p"
},
"source": [
"### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVlzwR03m8mC"
},
"source": [
"The combination of residuality\n",
"(changes are added to the current input)\n",
"and low rank\n",
"(only a small subspace is changed by each head)\n",
"drastically changes the intuition about Transformers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qqjZI2jKe6HH"
},
"source": [
"Rather than being an \"embedding of a token in its context\",\n",
"the residual stream becomes something more like a memory or a scratchpad:\n",
"one layer reads a small bit of information from the stream\n",
"and writes a small bit of information back to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YIBkxlqepjc"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-layer-residual.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RtsKhkLfk00l"
},
"source": [
"The residual stream works like a memory because it is roomy enough\n",
"that these actions need not interfere:\n",
"the subspaces targeted by reads and writes are small relative to the ambient space,\n",
"so different heads can read from and write to largely non-overlapping subspaces.\n",
"\n",
"Additionally, the dimension of each head is still in the 100s in large models,\n",
"and\n",
"[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n",
"in them, so the number of effective degrees of freedom is\n",
"actually larger than the dimension.\n",
"This phenomenon allows high-dimensional tensors to serve as\n",
"[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n",
"There are\n",
"[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n",
"\n",
"Together, this means an early layer can write information to the stream\n",
"that can be used by later layers -- by many of them at once, possibly much later.\n",
"Later layers can learn to edit this information,\n",
"e.g. deleting it,\n",
"if doing so reduces the loss,\n",
"but by default the information is preserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EragIygzJg86"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-stream-read-write.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKIaUZjwkpW7"
},
"source": [
"Lastly, the softmax in the attention has a sparsifying effect,\n",
"and so many attention heads are reading from \n",
"just one token and writing to just one other token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dN6VcJqIMKnB"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-token-to-token.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Repeatedly reading information from an external memory\n",
"and using it to decide which operation to perform\n",
"and where to write the results\n",
"is at the core of the\n",
"[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n",
"For a concrete example, the\n",
"[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n",
"includes a dissection of a form of \"pointer arithmetic\"\n",
"that appears in some models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kLFh7Mvnolr"
},
"source": [
"This point of view seems\n",
"very promising for explaining numerous\n",
"otherwise perhaps counterintuitive features of Transformer models.\n",
"\n",
"- This framework predicts that Transformers will readily copy and paste information,\n",
"which might explain phenomena like\n",
"[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n",
"\n",
"- It also readily explains\n",
"[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n",
"an important component of why Transformers perform well on medium-length texts\n",
"and in few-shot learning.\n",
"\n",
"- Transformers also perform better on reasoning tasks when the text\n",
"[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n",
"is added to their input prompt.\n",
"This is partly due to the fact that that prompt is associated,\n",
"in the dataset, with clearer reasoning,\n",
"and since the models are trained to predict which tokens tend to appear\n",
"after an input, they tend to produce better reasoning with that prompt --\n",
"an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n",
"that act to store intermediate information,\n",
"making for a richer residual stream\n",
"for reading and writing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RyLRzgG-93yB"
},
"source": [
"### Implementation detail: Transformers are position-insensitive by default."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oR6PnrlA_hJ2"
},
"source": [
"In the attention calculation\n",
"each token can query each other token,\n",
"with no regard for order.\n",
"Furthermore, the construction of queries, keys, and values\n",
"is based on the content of the embedding vector,\n",
"which does not automatically include its position.\n",
"\"dog bites man\" and \"man bites dog\" are identical, as in\n",
"[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n",
"\n",
"For most sequences,\n",
"this is unacceptable:\n",
"absolute and relative position matter\n",
"and we cannot use the future to predict the past.\n",
"\n",
"We need to add two pieces to get a Transformer architecture that's usable for next-token prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWHxGJz2-6ZK"
},
"source": [
"First, the simpler piece:\n",
"\"causal\" attention,\n",
"so-named because it ensures that values earlier in the sequence\n",
"are not influenced by later values, which would\n",
"[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0c42xi6URYB4"
},
"source": [
"The most common solution is straightforward:\n",
"we calculate attention between all tokens,\n",
"then throw out non-causal values by \"masking\" them\n",
"(this is before applying the softmax,\n",
"so masking means adding $-\\infty$).\n",
"\n",
"This feels wasteful --\n",
"why are we calculating values we don't need?\n",
"Trying to be smarter would be harder,\n",
"and might rely on operations that aren't as optimized as\n",
"matrix multiplication and addition.\n",
"Furthermore, it's \"only\" twice as many operations,\n",
"so it doesn't even show up in $O$-notation.\n",
"\n",
"A sample attention mask generated by our code base is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NXaWe6pT-9jV"
},
"outputs": [],
"source": [
"from text_recognizer.models import transformer_util\n",
"\n",
"\n",
"attention_mask = transformer_util.generate_square_subsequent_mask(100)\n",
"\n",
"ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n",
"plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n",
"print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This solves our causality problem,\n",
"but we still don't have positional information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZamUE4WIoGS2"
},
"source": [
"The standard technique\n",
"is to add alternating sines and cosines\n",
"of increasing frequency to the embeddings\n",
"(there are\n",
"[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n",
"most notably\n",
"[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n",
"Each position in the sequence is then uniquely identifiable\n",
"from the pattern of these values.\n",
"\n",
"> Furthermore, for the same reason that\n",
" [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n",
" translations, e.g. relative positions, are fairly easy to express as linear transformations\n",
" of sines and cosines."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IDG2uOsaELU0"
},
"source": [
"We superimpose this positional information on our embeddings.\n",
"Note that because the model is residual,\n",
"this position information will be by default preserved\n",
"as it passes through the network,\n",
"so it doesn't need to be repeatedly added."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what this positional encoding looks like in our codebase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Zk62Q-a-1Ax"
},
"outputs": [],
"source": [
"PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n",
"\n",
"pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n",
"\n",
"ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n",
"print(pe[:4, :8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep2ClIWvqDms"
},
"source": [
"When we add the positional information to our embeddings,\n",
"both the embedding information and the positional information\n",
"are approximately preserved,\n",
"as can be visually assessed below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJuFjoCzC0Y4"
},
"outputs": [],
"source": [
"fake_embeddings = torch.randn_like(pe) * 0.5\n",
"\n",
"ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n",
"\n",
"fake_embeddings_with_pe = fake_embeddings + pe\n",
"\n",
"plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHIzBxDkEmH8"
},
"source": [
"A [similar technique](https://arxiv.org/abs/2103.06450)\n",
"is used to also incorporate positional information into the image embeddings,\n",
"which are flattened before being fed to the decoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HC1N85wl8dvn"
},
"source": [
"### Learn more about Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lJwYxkjTk15t"
},
"source": [
"We're only able to give a flavor and an intuition for Transformers here.\n",
"\n",
"To improve your grasp on the nuts and bolts, check out the\n",
"[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n",
"which is surprisingly approachable,\n",
"as far as ML research papers go.\n",
"The\n",
"[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n",
"adds code and commentary to the original paper,\n",
"which makes it even more digestible.\n",
"For something even friendlier, check out the\n",
"[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n",
"by Jay Alammar, which has an accompanying\n",
"[video](https://youtu.be/-QH8fRhqFHM).\n",
"\n",
"Anthropic's work on\n",
"[Transformer Circuits](https://transformer-circuits.pub/),\n",
"summarized above, has some of the best material\n",
"for building theoretical understanding\n",
"and is still being updated with extensions and applications of the framework.\n",
"The\n",
"[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n",
"are a great aid for checking and building your understanding.\n",
"\n",
"But they are fairly math-heavy.\n",
"If you have more of a software engineering background, see\n",
"Transformer Circuits co-author Nelson Elhage's blog post\n",
"[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n",
"\n",
"For a gentler introduction to the intuition for Transformers,\n",
"check out Brandon Rohrer's\n",
"[Transformers From Scratch](https://e2eml.school/transformers.html)\n",
"tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qg7zntJES-aT"
},
"source": [
"An aside:\n",
"the matrix multiplications inside attention dominate\n",
"the big-$O$ runtime of Transformers.\n",
"So trying to make the attention mechanism more efficient, e.g. linear time,\n",
"has generated a lot of research\n",
"(review paper\n",
"[here](https://arxiv.org/abs/2009.06732)).\n",
"Despite drawing a lot of attention, so to speak,\n",
"at the time of writing in mid-2022, these methods\n",
"[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n",
"so it isn't likely to be worth the effort to spend time learning about them\n",
"unless you are a Transformer specialist."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCjXysEJ8g9_"
},
"source": [
"# Using Transformers to read paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsfKWnOvqjva"
},
"source": [
"Our simple convolutional model for text recognition from\n",
"[Lab 02b](https://fsdl.me/lab02b-colab)\n",
"could only handle cleanly-separated characters.\n",
"\n",
"It worked by sliding a LeNet-style CNN\n",
"over the image,\n",
"predicting a character for each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "njLdzBqy-I90"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist_lines = text_recognizer.data.EMNISTLines()\n",
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"\n",
"# for sliding, see the for loop over range(S)\n",
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0N6yDBQq8ns"
},
"source": [
"But unfortunately for us, handwritten text\n",
"doesn't come in neatly-separated characters\n",
"of equal size, so we trained our model on synthetic data\n",
"designed to work with that model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiqUVbj0sxLr"
},
"source": [
"Now that we have a better model,\n",
"we can work with better data:\n",
"paragraphs from the\n",
"[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oizsOAcKs-dD"
},
"source": [
"The cell below uses our `LightningDataModule`\n",
"to download and preprocess this data,\n",
"writing results to disk.\n",
"We can then spin up `DataLoader`s to give us batches.\n",
"\n",
"It can take several minutes to run the first time\n",
"on commodity machines,\n",
"with most time spent extracting the data.\n",
"On subsequent runs,\n",
"the time-consuming operations will not be repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uL9LHbjdsUbm"
},
"outputs": [],
"source": [
"iam_paragraphs = text_recognizer.data.IAMParagraphs()\n",
"\n",
"iam_paragraphs.prepare_data()\n",
"iam_paragraphs.setup()\n",
"xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n",
"\n",
"iam_paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBkFN9bbTm_S"
},
"source": [
"Now that we've got a batch,\n",
"let's take a look at some samples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hqaps8yxtBhU"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import wandb\n",
"\n",
"\n",
"def show(y):\n",
" y = y.detach().cpu() # bring back from accelerator if it's being used\n",
" return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 04: Experiment Management"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How experiment management brings observability to ML model development\n",
"- Which features of experiment management we use in developing the Text Recognizer\n",
"- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 4\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lab contains a large number of embedded iframes\n",
"that benefit from having a wide window.\n",
"The cell below makes the notebook as wide as your browser window\n",
"if `full_width` is set to `True`.\n",
"Full width is the default behavior in Colab,\n",
"so this cell is intended to improve the viewing experience in other Jupyter environments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Why experiment management?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To understand why we need experiment management for ML development,\n",
"let's start by running an experiment.\n",
"\n",
"We'll train a new model on a new dataset,\n",
"using the training script `training/run_experiment.py`\n",
"introduced in [Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use a CNN encoder and Transformer decoder, as in\n",
"[Lab 03](https://fsdl.me/lab03-colab),\n",
"but with some changes so we can iterate faster.\n",
"We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab),\n",
"and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n",
"from text_recognizer.data import IAMLines # processed version split into individual lines\n",
"from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n",
"\n",
"\n",
"print(IAM.__doc__)\n",
"\n",
"# uncomment a line below for details on either class\n",
"# IAMLines?? \n",
"# LineCNNTransformer??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below will train a model on 10% of the data for two epochs.\n",
"\n",
"It takes up to a few minutes to run on commodity hardware,\n",
"including data download and preprocessing.\n",
"As it's running, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"import torch\n",
"\n",
"\n",
"gpus = int(torch.cuda.is_available()) \n",
"\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the model trains, we're calculating lots of metrics --\n",
"loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n",
"and reporting them to the terminal.\n",
"\n",
"This is achieved by the built-in `.log` method\n",
"([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n",
"of the `LightningModule`,\n",
"and it is a very straightforward way to get basic information about your experiment as it's running\n",
"without leaving the context where you're running it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Learning to read\n",
"[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n",
"is something of a rite of passage for MLEs, but\n",
"let's consider what we can't see here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We're missing all metric values except the most recent --\n",
"we can see them as they stream in, but they're constantly overwritten.\n",
"We also can't associate them with timestamps, steps, or epochs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We also don't see any system metrics.\n",
"We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n",
"without launching a separate process.\n",
"And even if we do, those values will also not be saved and timestamped,\n",
"so we can't correlate them with other things during training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As we continue to run experiments, changing code and opening new terminals,\n",
"even the information we have or could figure out now will disappear.\n",
"Say you spot a weird error message during training,\n",
"but your session ends and the stdout is gone,\n",
"so you don't know exactly what it was.\n",
"Can you recreate the error?\n",
"Which git branch and commit were you on?\n",
"Did you have any uncommitted changes? Which arguments did you pass?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Also, model checkpoints containing the parameter values have been saved to disk.\n",
"Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n",
"As we run more and more experiments,\n",
"we'll want to slice and dice them to see if,\n",
"say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to save and log all of this information, and more, in order to make our model training\n",
"[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n",
"in short, so that we can understand, make decisions about, and debug our model training\n",
"by looking at logs and source code, without having to recreate it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we had to write the code to log all of this information ourselves, we'd be in for a world of hurt:\n",
"1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n",
"2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n",
"3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n",
"4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local Experiment Tracking with Tensorboard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n",
"and it makes logging very easy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n",
"\n",
"It can also use a logger to push information elsewhere.\n",
"\n",
"By default, we use\n",
"[TensorBoard](https://www.tensorflow.org/tensorboard)\n",
"via the Lightning `TensorBoardLogger`,\n",
"which has been saving results to the local disk.\n",
"\n",
"Let's find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest experiment's directory\n",
"# by hand, you can just copy and paste it from the terminal\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n",
"filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n",
"latest_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!ls -lh {latest_log}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, we need to launch a TensorBoard server --\n",
"much like we need to launch a Jupyter server to use Jupyter notebooks.\n",
"\n",
"The cells below load an extension that lets you use TensorBoard inside of a notebook\n",
"the same way you'd use it from the command line, and then launch it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n",
"\n",
"port = 11717 # pick an open port on your machine\n",
"host = \"0.0.0.0\" # allow connections from the internet\n",
" # watch out! make sure you turn TensorBoard off\n",
"\n",
"%tensorboard --logdir {latest_log} --port {port} --host {host}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see some charts of metrics over time along with some charting controls.\n",
"\n",
"You can click around in this interface and explore it if you'd like,\n",
"but in the next section, we'll see that there are better tools for experiment management."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've run many experiments on this machine,\n",
"you can see all of their results by pointing TensorBoard\n",
"at the whole `lightning_logs` directory,\n",
"rather than just one experiment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For large numbers of experiments, the management experience is not great --\n",
"it's for example hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n",
"\n",
"It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n",
"which are important workflows as applications mature and teams grow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard runs as an independent process, so we need to make sure we turn it off when we're done: just flip `done_with_tensorboard` to `True` in the cell below and run it.\n",
"\n",
"If you run into any issues with the above cells failing to launch,\n",
"especially across iterations of this lab, run this cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorboard.manager\n",
"\n",
"# get the process IDs for all tensorboard instances\n",
"pids = [tb.pid for tb in tensorboard.manager.get_all()]\n",
"\n",
"done_with_tensorboard = False\n",
"\n",
"if done_with_tensorboard:\n",
" # kill processes\n",
" for pid in pids:\n",
" !kill {pid} 2> /dev/null\n",
" \n",
" # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n",
" !rm -rf {tensorboard.manager._get_info_dir()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiment Management with Weights & Biases"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How do we manage experiments when we hit the limits of local TensorBoard?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is powerful and flexible and very scalable,\n",
"but running it requires engineering effort and babysitting --\n",
"you're running a database, writing data to it,\n",
"and layering a web application over it.\n",
"\n",
"This is a fairly common workflow for web developers,\n",
"but not so much for ML engineers.\n",
"\n",
"You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n",
"and it's as simple as running the command `tensorboard dev upload`\n",
"pointed at your logging directory.\n",
"\n",
"But there are strict limits to this free service:\n",
"1GB of tensor data and 1GB of binary data.\n",
"A single Text Recognizer model checkpoint is ~100MB,\n",
"and that's not particularly large for a useful model.\n",
"\n",
"Furthermore, all data is public,\n",
"so if you upload the inputs and outputs of your model,\n",
"anyone who finds the link can see them.\n",
"\n",
"Overall, tensorboard.dev works very well for certain academic and open projects\n",
"but not for industrial ML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To avoid those permissions and storage-limit issues,\n",
"you could use [git LFS](https://git-lfs.github.com/)\n",
"to track the binary data and tensor data,\n",
"which are more likely to be sensitive than metrics.\n",
"\n",
"The Hugging Face ecosystem uses TensorBoard and git LFS.\n",
"\n",
"It includes the Hugging Face Hub, a git server much like GitHub,\n",
"but designed first and foremost for collaboration on models and datasets,\n",
"rather than collaboration on code.\n",
"For example, the Hugging Face Hub\n",
"[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n",
"and officially has\n",
"[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n",
"avoiding the\n",
"[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n",
"that make using git LFS with GitHub expensive.\n",
"\n",
"However, we prefer to avoid mixing software version control and experiment management.\n",
"\n",
"First, using the Hub requires maintaining an additional git remote,\n",
"which is a hard ask for many engineering teams.\n",
"\n",
"Second, git-style versioning is an awkward fit for logging --\n",
"is it really sensible to create a new commit for each logging event while you're watching live?\n",
"\n",
"Instead, we prefer to use systems that solve experiment management with _databases_."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n",
"The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n",
"tool is [MLflow](https://github.com/mlflow/mlflow/)\n",
"and there are a number of\n",
"[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n",
"\n",
"These tools generally avoid any need to worry about hosting\n",
"(unless data governance rules require a self-hosted version).\n",
"\n",
"For a sampling of publicly-posted opinions on experiment management tools,\n",
"see these discussions from Reddit:\n",
"\n",
"- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n",
"- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n",
"\n",
"Among these tools, the FSDL recommendation is\n",
"[Weights & Biases](https://wandb.ai),\n",
"which we believe offers\n",
"- the best user experience, both in the Python SDKs and in the graphical interface\n",
"- the best integrations with other tools,\n",
"including\n",
"[Lightning](https://docs.wandb.ai/guides/integrations/lightning),\n",
"[Keras](https://docs.wandb.ai/guides/integrations/keras),\n",
"[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n",
"and even\n",
"[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n",
"and\n",
"- the best tools for collaboration.\n",
"\n",
"Below, we'll take care to point out which logging and management features\n",
"are available via generic interfaces in Lightning and which are W&B-specific."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"print(wandb.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding it to our experiment running code is extremely easy,\n",
"relative to the features we get, which is\n",
"one of the main selling points of W&B.\n",
"\n",
"We get most of our new experiment management features just by changing a single variable, `logger`, from\n",
"`TensorBoardLogger` to `WandbLogger`\n",
"and adding two lines of code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll see what each of these lines does for us below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this logger is built into and maintained by PyTorch Lightning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"\n",
"WandbLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to complete the rest of this notebook,\n",
"you'll need a Weights & Biases account.\n",
"\n",
"As with GitHub, the free tier for personal, academic, and open source work\n",
"is very generous.\n",
"\n",
"The Text Recognizer project will fit comfortably within the free tier.\n",
"\n",
"Run the cell below and follow the prompts to log in or create an account,\n",
"or sign up [here](https://wandb.ai/signup)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the cell below to launch an experiment tracked with Weights & Biases.\n",
"\n",
"The experiment can take between 3 and 10 minutes to run.\n",
"In that time, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n",
" --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1\n",
" \n",
"last_expt = wandb.run\n",
"\n",
"wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see some new things in our output.\n",
"\n",
"For example, there's a note from `wandb` that the data is saved locally\n",
"and also synced to their servers.\n",
"\n",
"There's a link to a webpage for viewing the logged data and a name for our experiment --\n",
"something like `dandy-sunset-1`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The local logging and cloud syncing happen with minimal impact on performance,\n",
"because `wandb` launches a separate process to listen for events and upload them.\n",
"\n",
"That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, head to the link in the notebook output\n",
"that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n",
"\n",
"There's no need to wait for training to finish.\n",
"\n",
"The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(last_expt.url)\n",
"IFrame(last_expt.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have landed on the run page\n",
"([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n",
"which collects up all of the information for a single experiment into a collection of tabs.\n",
"\n",
"We'll work through these tabs from top to bottom.\n",
"\n",
"Each header is also a link to the documentation for a tab."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n",
"This tab has an icon that looks like `(i)` or 🛈.\n",
"\n",
"The top section of this tab has high-level information about our run:\n",
"- Timing information, like start time and duration\n",
"- System hardware, hostname, and basic environment info\n",
"- Git repository link and state\n",
"\n",
"This information is collected and logged automatically.\n",
"\n",
"The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n",
"and summary metrics.\n",
"\n",
"Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n",
"\n",
"This tab has a line plot icon, something like 📈.\n",
"\n",
"It's also the default page you land on when looking at a W&B run.\n",
"\n",
"Charts are generated for everything we `.log` from PyTorch Lightning. The charts here are interactive and editable, and changes persist.\n",
"\n",
"Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n",
"\n",
"We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n",
"This tab has a computer chip icon.\n",
"\n",
"It contains\n",
"- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n",
"- CPU metrics: memory usage, utilization, thread counts\n",
"- Disk and network I/O levels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n",
"This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n",
"\n",
"The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n",
"This tab has an icon that looks like a stylized command prompt, `>_`.\n",
"\n",
"It contains information that was printed to stdout.\n",
"\n",
"This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n",
"\n",
"Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelSummary\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n",
"\n",
"For more on Lightning `Callback`s, see\n",
"[Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n",
"This tab has a stylized document icon, something like 📄.\n",
"\n",
"You can use this tab to view any files saved with `wandb.save`.\n",
"\n",
"For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n",
"which we'll discuss shortly.\n",
"\n",
"But a few pieces of information automatically collected by W&B end up in this tab.\n",
"\n",
"Some highlights:\n",
" - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n",
" - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n",
"This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n",
"\n",
"This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n",
"\n",
"We store two kinds of binary files:\n",
" - `run_table`s of model inputs and outputs\n",
" - `model` checkpoints\n",
"\n",
"We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interactive Tables of Logged Media"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returning to the Charts tab,\n",
"notice that we have model inputs and outputs logged in structured tables\n",
"under the train, validation, and test sections.\n",
"\n",
"These tables are interactive as well\n",
"([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n",
"They support basic exploratory data analysis and are compatible with W&B's collaboration features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to charts in our run page, these tables also have their own pages inside the W&B web app."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n",
"table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n",
"\n",
"print(table_data_url)\n",
"IFrame(src=table_data_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting this to work requires more effort and more W&B-specific code\n",
"than the other features we've seen so far.\n",
"\n",
"We'll briefly explain the implementation here, for those who are interested.\n",
"\n",
"We use a custom Lightning `Callback`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n",
"\n",
"\n",
"ImageToTextTableLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n",
"\n",
"The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n",
"\n",
"This behavior is sensible for metrics, which are low overhead, but not so much for media,\n",
"where we'd rather subsample and avoid holding on to too much information.\n",
"\n",
"So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n",
"\n",
"The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.lit_models.base import BaseImageToTextLitModel\n",
"\n",
"BaseImageToTextLitModel.add_on_logged_batches??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Projects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Everything we've seen so far has been related to a single run or experiment.\n",
"\n",
"Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n",
"\n",
"We organize our runs into \"projects\" and view them on the W&B \"project page\" \n",
"([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n",
"\n",
"By default in the Lightning integration, the project name is determined based on directory information.\n",
"This default can be overridden in the code when creating a `WandbLogger`,\n",
"but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what the project page looks like for a longer-running project with lots of experiments.\n",
"\n",
"The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n",
"\n",
"print(project_url)\n",
"IFrame(src=project_url, width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n",
"\n",
"We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and rearrange the charts."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Artifacts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond logging metrics and metadata from runs,\n",
"we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below pulls up all of the artifacts associated with the experiment we just ran."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Click on one of the `model` checkpoints -- the specific version doesn't matter.\n",
"\n",
"There are a number of tabs here.\n",
"\n",
"The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n",
"\n",
"The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n",
"which are added by default by the `WandbLogger`.\n",
"\n",
"The \"Files\" tab contains the actual file contents of the artifact.\n",
"\n",
"On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n",
"including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n",
"\n",
"You can click on these to explore the different versions and even directly compare them.\n",
"\n",
"If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Artifact storage is part of the W&B free tier.\n",
"\n",
"The storage limits, as of August 2022, cover 100GB of artifacts and experiment data.\n",
"\n",
"That is enough to store ~700 model checkpoints for the Text Recognizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can track your data storage and compare it to your limits at this URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n",
"\n",
"print(storage_tracker_url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Programmatic Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also programmatically access our data and metadata via the `wandb` API\n",
"([docs](https://docs.wandb.ai/guides/track/public-api-guide)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wb_api = wandb.Api()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run = wb_api.run(\"/\".join( # fetch a run given\n",
" [last_expt.entity, # the user or org it was logged to\n",
" last_expt.project, # the \"project\", usually one of several per repo/application\n",
" last_expt.id] # and a unique ID\n",
"))\n",
"\n",
"hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n",
"\n",
"hist.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hist.groupby(\"epoch\")[\"train/loss\"].mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The run object also gives us access to the artifacts it logged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# which artifacts were created and logged?\n",
"artifacts = run.logged_artifacts()\n",
"\n",
"for artifact in artifacts:\n",
" print(f\"artifact of type {artifact.type}: {artifact.name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thanks to our `ImageToTextTableLogger`,\n",
"we can easily recreate training or validation data that came out of our `DataLoader`s,\n",
"which is normally ephemeral:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n",
"artifact_dir = Path(artifact.download(root=\"training/logs\"))\n",
"image_dir = artifact_dir / \"media\" / \"images\"\n",
"\n",
"images = [path for path in image_dir.iterdir()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"from IPython.display import Image\n",
"\n",
"Image(str(random.choice(images)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Advanced W&B API Usage: MLOps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the strengths of a well-instrumented experiment tracking system is that it allows\n",
"automatic relation of information:\n",
"what were the inputs when this model's gradient spiked?\n",
"Which models have been trained on this dataset,\n",
"and what was their performance?\n",
"\n",
"Having access and automation around this information is necessary for \"MLOps\",\n",
"which applies contemporary DevOps principles to ML projects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below pull down the training data\n",
"for the model currently running the FSDL Text Recognizer app.\n",
"\n",
"This is just intended as a demonstration of what's possible,\n",
"so don't worry about understanding every piece of this,\n",
"and feel free to skip past it.\n",
"\n",
"MLOps is still a nascent field, and these tools and workflows are likely to change.\n",
"\n",
"For example, just before the course launched, W&B released a\n",
"[Model Registry layer](https://docs.wandb.ai/guides/models)\n",
"on top of artifact logging that aims to improve the developer experience for these workflows."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start from the same project we looked at in the project view:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n",
"\n",
"text_recognizer_project "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we search it for the text recognizer model currently being used in production:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# collect all versions of the text-recognizer ever put into production by...\n",
"\n",
"for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n",
" if art_type.name == \"prod-ready\": # for the prod-ready type\n",
" # and grabbing the text-recognizer\n",
" production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n",
"\n",
"# and then get the one that's currently being tested in CI by...\n",
"for text_recognizer in production_text_recognizers:\n",
" if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n",
" in_prod_text_recognizer = text_recognizer\n",
"\n",
"# view its metadata at the url or in the notebook\n",
"in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n",
"\n",
"print(in_prod_text_recognizer_url)\n",
"IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From its metadata, we can get information about how it was \"staged\" to be put into production,\n",
"and in particular which model checkpoint was used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staging_run = in_prod_text_recognizer.logged_by()\n",
"\n",
"training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n",
"training_ckpt.name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That checkpoint was logged by a training experiment, which is available as metadata.\n",
"\n",
"We can look at the training run for that model, either here in the notebook or at its URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"training_run = training_ckpt.logged_by()\n",
"print(training_run.url)\n",
"IFrame(src=training_run.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And from there, we can access logs and metadata about training,\n",
"confident that we are working with the model that is actually in production.\n",
"\n",
"For example, we can pull down the data we logged and analyze it locally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_results = training_run.history(samples=10000)\n",
"training_results.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n",
"training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 10\n",
"training_results[\"validation/loss\"].dropna().iloc[idx]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The charts and webpages in Weights & Biases\n",
"are substantially more useful than ephemeral stdouts or raw logs on disk.\n",
"\n",
"If you're spun up on the project,\n",
"they accelerate debugging, exploration, and discovery.\n",
"\n",
"If not, they're not so much useful as they are overwhelming.\n",
"\n",
"We need to synthesize the raw logged data into information.\n",
"This helps us communicate our work with other stakeholders,\n",
"preserve knowledge and prevent repetition of work,\n",
"and surface insights faster.\n",
"\n",
"These workflows are supported by the W&B Reports feature\n",
"([docs here](https://docs.wandb.ai/guides/reports)),\n",
"which mixes W&B charts and tables with explanatory markdown text and embeds.\n",
"\n",
"Below are some common report patterns and\n",
"use cases and examples of each."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of the examples are from the FSDL Text Recognizer project.\n",
"You can find more of them\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n",
"where we've organized them into a report!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dashboard Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dashboards are a structured subset of the output from one or more experiments,\n",
"designed for quickly surfacing issues or insights,\n",
"like an accuracy or performance regression\n",
"or a change in the data distribution.\n",
"\n",
"Use cases:\n",
"- show the basic state of an ongoing experiment\n",
"- compare one experiment to another\n",
"- select the most important charts so you can spin back up into context on a project more quickly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n",
"\n",
"IFrame(src=dashboard_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pull Request Documentation Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In most software codebases,\n",
"pull requests are a key focal point\n",
"for units of work that combine\n",
"short-term communication and long-term information tracking.\n",
"\n",
"In ML codebases, it's more difficult to bring\n",
"sufficient information together to make PRs as useful.\n",
"At FSDL, we like to add documentary\n",
"reports with one or a small number of charts\n",
"that connect logged information in the experiment management system\n",
"to state in the version control software.\n",
"\n",
"Use cases:\n",
"- communication of results within a team, e.g. code review\n",
"- record-keeping that links pull request pages to raw logged info and makes it discoverable\n",
"- improving confidence in PR correctness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n",
"\n",
"IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Blog Post Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With sufficient effort, the logged data in the experiment management system\n",
"can be made clear enough to be consumed,\n",
"sufficiently contextualized to be useful outside the team, and\n",
"even beautiful.\n",
"\n",
"The result is a report that's closer to a blog post than a dashboard or internal document.\n",
"\n",
"Use cases:\n",
"- communication between teams or vertically in large organizations\n",
"- external technical communication for branding and recruiting\n",
"- attracting users or contributors\n",
"\n",
"Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n",
"[Boris Dayma](https://twitter.com/borisdayma)\n",
"and others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n",
"\n",
"IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameter Optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many of our choices, like the depth of our network, the nonlinearities of our layers,\n",
"and the learning rate and other parameters of our optimizer, cannot be\n",
"([easily](https://arxiv.org/abs/1606.04474))\n",
"chosen by descent of the gradient of a loss function.\n",
"\n",
"But these parameters that impact the values of the parameters\n",
"we directly optimize with gradients, or _hyperparameters_,\n",
"can still be optimized,\n",
"essentially by trying options and selecting the values that worked best.\n",
"\n",
"In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n",
"\n",
"Expending more compute can squeeze out small amounts of additional validation or test performance,\n",
"which makes for impressive results on leaderboards but typically doesn't translate\n",
"into better user experience.\n",
"\n",
"In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n",
"built into your other tooling.\n",
"\n",
"Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n",
"([docs](https://docs.wandb.ai/guides/sweeps)).\n",
"\n",
"It also supports a number of more advanced tools, like\n",
"[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n",
"for early termination of poorly-performing runs.\n",
"\n",
"We can use the same training script and we don't need to run an optimization server.\n",
"\n",
"We just need to write a configuration yaml file\n",
"([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n",
"like the one below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile training/simple-overfit-sweep.yaml\n",
"# first we specify what we're sweeping\n",
"# we specify a program to run\n",
"program: training/run_experiment.py\n",
"# we optionally specify how to run it, including setting default arguments\n",
"command: \n",
" - ${env}\n",
" - ${interpreter}\n",
" - ${program}\n",
" - \"--wandb\"\n",
" - \"--overfit_batches\"\n",
" - \"1\"\n",
" - \"--log_every_n_steps\"\n",
" - \"25\"\n",
" - \"--max_epochs\"\n",
" - \"100\"\n",
" - \"--limit_test_batches\"\n",
" - \"0\"\n",
" - ${args} # these arguments come from the sweep parameters below\n",
"\n",
"# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n",
"method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n",
"metric:\n",
" name: train/loss\n",
" goal: minimize\n",
"parameters: \n",
" # LineCNN hyperparameters\n",
" window_width:\n",
" values: [8, 16, 32, 64]\n",
" window_stride:\n",
" values: [4, 8, 16, 32]\n",
" # Transformer hyperparameters\n",
" tf_layers:\n",
" values: [1, 2, 4, 8]\n",
" # we can also fix some values, just like we set default arguments\n",
" gpus:\n",
" value: 1\n",
" model_class:\n",
" value: LineCNNTransformer\n",
" data_class:\n",
" value: IAMLines\n",
" loss:\n",
" value: transformer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the config we launch a \"controller\":\n",
"a lightweight process that just decides what hyperparameters to try next\n",
"and coordinates the heavier-weight training.\n",
"\n",
"This lives on the W&B servers, so there are no headaches about opening ports for communication,\n",
"cleaning up when it's done, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n",
"simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we can launch an \"agent\" to follow the orders of the controller:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# interrupt twice to terminate this cell if it's running too long,\n",
"# it can be over 15 minutes with some hyperparameters\n",
"\n",
"!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n",
"\n",
"If not provided, the agent will run forever for random or Bayesian sweeps\n",
"or until the sweep is terminated, which can be done from the W&B interface."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agents make for a slick workflow for distributing sweeps across GPUs.\n",
"\n",
"We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n",
"which controls which GPUs are accessible by a process, to launch\n",
"parallel agents on separate GPUs on the same machine."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n",
"# open another terminal\n",
"CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n",
"# and so on\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We include optional exercises with the labs for learners who want to dive deeper on specific topics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟 Contribute to a hyperparameter search."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n",
"\n",
"There are ~10,000,000 potential hyperparameter combinations,\n",
"and each takes 30 minutes to test,\n",
"so checking each possibility will take over 500 years of compute time.\n",
"Best get cracking then!\n",
"\n",
"Run the cell below to pull up a dashboard and print the URL where you can check on the current status."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_entity = \"fullstackdeeplearning\"\n",
"sweep_project = \"fsdl-line-recognizer-2022\"\n",
"sweep_id = \"e0eo43eu\"\n",
"sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n",
"\n",
"print(sweep_url)\n",
"IFrame(src=sweep_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also retrieve information about the sweep from the API,\n",
"including the hyperparameters being swept over."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hyperparams = sweep_info.config[\"parameters\"]\n",
"hyperparams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to contribute to this sweep,\n",
"run the cell below after changing the count to a number greater than 0.\n",
"\n",
"Each iteration runs for 30 minutes if it does not crash,\n",
"e.g. due to out-of-memory errors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"count = 0 # off by default, increase it to join in!\n",
"\n",
"if count:\n",
" !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Write some manual logging in `wandb`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the FSDL Text Recognizer codebase,\n",
"we almost exclusively log to W&B through Lightning,\n",
"rather than through the `wandb` Python SDK.\n",
"\n",
"If you're interested in learning how to use W&B directly, e.g. with another training framework,\n",
"try out this quick exercise that introduces the key players in the SDK."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n",
"\n",
"It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n",
"\n",
"Add W&B metric and artifact logging to this cell:\n",
"- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n",
"- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"import random\n",
"\n",
"import wandb\n",
"\n",
"\n",
"os.makedirs(\"logs\", exist_ok=True)\n",
"\n",
"project = \"trying-wandb\"\n",
"config = {\"steps\": 50}\n",
"\n",
"\n",
"with wandb.init(project=project, config=config) as run:\n",
" steps = wandb.config[\"steps\"]\n",
" \n",
" for ii in range(steps):\n",
" loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n",
" \n",
" with open(\"logs/hello.txt\", \"w\") as f:\n",
" f.write(\"hello from wandb, my dudes!\")\n",
" \n",
" run_id = run.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hello_run = wb_api.run(f\"{project}/{run_id}\")\n",
"\n",
"# check for logged loss data\n",
"if \"loss\" not in hello_run.history().keys():\n",
" print(\"loss not logged 🥲\")\n",
"else:\n",
" print(\"loss logged successfully 🥞\")\n",
" if len(hello_run.history()[\"loss\"]) != steps:\n",
" print(\"loss not logged on all steps 🥲\")\n",
" else:\n",
" print(\"loss logged on all steps 🥞\")\n",
"\n",
"artifacts = hello_run.logged_artifacts()\n",
"\n",
"# check for artifact with the right name\n",
"if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n",
" print(\"hello artifact not logged 🥲\")\n",
"else:\n",
" print(\"hello artifact logged successfully 🥞\")\n",
" # check for the file inside the artifacts\n",
" if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n",
" print(\"could not find hello.txt 🥲\")\n",
" else:\n",
" print(\"hello.txt logged successfully 🥞\")\n",
" \n",
" \n",
"hello_run"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try to find some better hyperparameters: choices that achieve a lower loss on the full dataset faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you observe interesting phenomena during training,\n",
"from promising hyperparameter combos to software bugs to strange model behavior,\n",
"turn the charts into a W&B report and share it with the FSDL community or\n",
"[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n",
"with a link to them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# check the sweep_info.config above to see the model and data hyperparameters\n",
"# read through the --help output for all potential arguments\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n",
" --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n",
" --help # remove this line to run an experiment instead of printing help\n",
" \n",
"last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟🌟🌟 Add logging of tensor statistics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to logging model inputs and outputs as human-interpretable media,\n",
"it's also frequently useful to see information about their numerical values."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're interested in learning more about metric calculation and logging with Lightning,\n",
"use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n",
"to add tensor statistic logging to the `LineCNNTransformer`.\n",
"\n",
"`torchmetrics` comes with built-in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n",
"\n",
"All three are useful, but start by adding just one."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n",
"- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n",
" - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n",
"- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n",
" - Base your code on the calculation and logging of the `val_cer` metric.\n",
" - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ",
"collapsed_sections": [],
"name": "lab04_experiments.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab07/notebooks/lab05_troubleshooting.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 05: Troubleshooting & Testing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest` and `doctests`\n",
"- How to implement tests for ML training systems in particular\n",
"- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 5\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sThWeTtV6fL_"
},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xFP8lU4nSg1P"
},
"source": [
"# Linting Python and Shell Scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cXbdYfFlPhZ-"
},
"source": [
"### Automatically linting with `pre-commit`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ysqqb2GjvLrz"
},
"source": [
"We want to keep our code clean and uniform across developers\n",
"and time.\n",
"\n",
"Applying the cleanliness checks and style rules should be\n",
"as painless and automatic as possible.\n",
"\n",
"For this purpose, we recommend bundling linting tools together\n",
"and enforcing them on all commits with\n",
"[`pre-commit`](https://pre-commit.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XvqtZChKvLr0"
},
"source": [
"In addition to running on every commit,\n",
"`pre-commit` separates the model development environment from the environments\n",
"needed for the linting tools, preventing conflicts\n",
"and simplifying maintenance and onboarding."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y0XuIuKOXhJl"
},
"source": [
"This cell runs `pre-commit`.\n",
"\n",
"The first time it is run on a machine, it will install the environments for all tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hltYGbpNvLr1"
},
"outputs": [],
"source": [
"!pre-commit run --all-files"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gLw08gIkvLr1"
},
"source": [
"The output lists all the checks that are run and whether they are passed.\n",
"\n",
"Notice there are a number of simple version-control hygiene practices included\n",
"that aren't even specific to Python, much less to machine learning.\n",
"\n",
"For example, several of the checks prevent accidental commits with private keys, large files, \n",
"leftover debugger statements, or merge conflict annotations in them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RHEEjb9kvLr1"
},
"source": [
"These linting actions are configured via\n",
"([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n",
"a YAML file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dgXa8BzrvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8HYc_WbTvLr2"
},
"source": [
"Most of the general cleanliness checks are from hooks built by `pre-commit`.\n",
"\n",
"See the comments and links in the `.pre-commit-config.yaml` for more:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K9rTgRqzvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep repos -A 15"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ptkO7aPvLr2"
},
"source": [
"Let's take a look at the section of the file\n",
"that applies most of our Python style enforcement with\n",
"[`flake8`](https://flake8.pycqa.org/en/latest/):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ALsRKfcevLr3",
"scrolled": true
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_Q0BwQUXbg6"
},
"source": [
"The majority of the style checking behavior we want comes from the\n",
"`additional_dependencies`, which are\n",
"[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n",
"that extend `flake8`'s list of lints.\n",
"\n",
"Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n",
"\n",
"We keep the configuration information for `flake8`\n",
"separate from that for `pre-commit`\n",
"in case we want to use additional tools with `flake8`,\n",
"e.g. if some developers want to integrate it directly into their editor,\n",
"and so that if we change away from `pre-commit`\n",
"but keep `flake8` we don't have to\n",
"recreate our configuration in a different tool.\n",
"\n",
"As much as possible, codebases should strive for single sources of truth\n",
"and link back to those sources of truth with documentation or comments,\n",
"as in the last line above.\n",
"\n",
"Let's take a look at the contents of `.flake8`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "doC_4WQwvLr3"
},
"outputs": [],
"source": [
"!cat .flake8"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Nq6HnyU0M47"
},
"source": [
"There's a lot here! We'll focus on the most important bits."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U4PiB8CPvLr3"
},
"source": [
"Linting tools in Python generally work by emitting error codes\n",
"with one or more letters followed by three numbers.\n",
"The `select` argument picks which error codes we want to check for.\n",
"Error codes are matched by prefix,\n",
"so for example `B` matches `BTS101` and\n",
"`G1` matches `G102` and `G199` but not `ARG404`.\n",
"\n",
"Certain codes are `ignore`d in the default `flake8` style,\n",
"which is done via the `ignore` argument,\n",
"and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n",
"For example, we rely on `black` to do our formatting,\n",
"so we ignore some of `flake8`'s formatting codes.\n",
"\n",
"Together, these settings define our project's particular style.\n",
"\n",
"But not every file fits this style perfectly.\n",
"Most of the conventions in `black` and `flake8` come from the style-defining\n",
"[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n",
"which exhorts you to \"know when to be inconsistent\".\n",
"\n",
"To allow ourselves to be inconsistent when we know we should be,\n",
"`flake8` includes `per-file-ignores`,\n",
"which let us ignore specific warnings in specific files.\n",
"This is one of the \"escape valves\"\n",
"that makes style enforcement tolerable.\n",
"We can also `exclude` files in the `pre-commit` config itself.\n",
"\n",
"For details on selecting and ignoring,\n",
"see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html)\n",
"\n",
"For definitions of the error codes from `flake8` itself,\n",
"see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n",
"Individual extensions list their added error codes in their documentation,\n",
"e.g. `darglint` does so\n",
"[here](https://github.com/terrencepreilly/darglint#error-codes)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NL0TpyPsvLr4"
},
"source": [
"The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n",
"\n",
"You can read more about each in their documentation:\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mFsZC0a7vLr4"
},
"source": [
"### Linting via a script and using `shellcheck`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RYjpuFwjXkJc"
},
"source": [
"To avoid needing to think about `pre-commit`\n",
"(was the command `pre-commit run` or `pre-commit check`?)\n",
"while developing locally,\n",
"we might put our linters into a shell script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mXlLFWmavLr4"
},
"outputs": [],
"source": [
"!cat tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PPxHpRIB3nbw"
},
"source": [
"These kinds of short and simple shell scripts are common in projects\n",
"of intermediate size.\n",
"\n",
"They are useful for adding automation and reducing friction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TMuPBpAi2qwl"
},
"source": [
"But these scripts are code,\n",
"and all code is susceptible to bugs and subject to concerns of style consistency."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SQRg3ZqXvLr4"
},
"source": [
"We can't check these scripts with tools that lint Python code,\n",
"so we include a shell script linting tool,\n",
"[`shellcheck`](https://www.shellcheck.net/),\n",
"in our `pre-commit`.\n",
"\n",
"More so than checking for correct style,\n",
"this tool checks for common bugs or surprising behaviors of shells,\n",
"which are unfortunately numerous."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zkfhE1srvLr4"
},
"outputs": [],
"source": [
"script_filename = \"tasks/lint.sh\"\n",
"!pre-commit run shellcheck --files {script_filename}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXU9TRrwvLr4"
},
"source": [
"That script has already been tested, so we don't see any errors.\n",
"\n",
"Try copying over a script you've written yourself or\n",
"even from a popular repo that you like\n",
"(by adding to the notebook directory or by making a cell\n",
"with `%%writefile` at the top)\n",
"and test it by changing the `script_filename`.\n",
"\n",
"You'd be surprised at the classes of subtle bugs possible in bash!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "81MhAL-TvLr5"
},
"source": [
"### Try \"unofficial bash strict mode\" for louder failures in scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hSwhs_zUvLr5"
},
"source": [
"Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n",
"[@redsymbol](https://twitter.com/redsymbol),\n",
"which appear at the top of the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o-j0vSxEvLr5"
},
"outputs": [],
"source": [
"!head -n 3 tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d2iJU5jlvLr5"
},
"source": [
"The core idea of strict mode is to fail more loudly.\n",
"This is a desirable behavior of scripts,\n",
"like the ones we're writing,\n",
"even though it's an undesirable behavior for an interactive shell --\n",
"it would be unpleasant to be logged out every time you hit an error.\n",
"\n",
"`set -u` means scripts fail if a variable's value is `u`nset,\n",
"i.e. not defined.\n",
"Otherwise bash is perfectly happy to allow you to reference undefined variables.\n",
"The result is just an empty string, which can lead to maddeningly weird behavior.\n",
"\n",
"`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n",
"rather than using the exit code of the last command.\n",
"Unix tools are perfectly happy to work on nonsense input,\n",
"like sorting error messages, instead of the filenames you meant to send.\n",
"\n",
"You can read more about these choices\n",
"[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n",
"along with considerations for working with other non-conforming scripts in \"strict mode\"\n",
"and for handling resource teardown when scripts error out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s1XqsrU_XWWS"
},
"source": [
"# Testing ML Codebases"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPNzeq3NYF2W"
},
"source": [
"## Testing Python code with `pytest`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zq5e_x6gc9Vu"
},
"source": [
"\n",
"ML codebases are Python first and foremost, so first let's get some Python tests going."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DC3GxYz6_R9"
},
"source": [
"At a basic level,\n",
"we can write functions that `assert`\n",
"that our code behaves as expected in\n",
"a given scenario and include them in the same module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rvd-GNwv63W1"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.metrics import test_character_error_rate\n",
"\n",
"test_character_error_rate??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iVB2TsQS5BTq"
},
"source": [
"The standard tool for testing Python code is\n",
"[`pytest`](https://docs.pytest.org/en/7.1.x/).\n",
"\n",
"We can use it as a command-line tool in a variety of ways,\n",
"including to execute these kinds of tests.\n",
"\n",
"If passed a filename, `pytest` will look for\n",
"any classes that start with `Test` or\n",
"any functions that start with `test_` and run them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u8sQguyJvLr6",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest text_recognizer/lit_models/metrics.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "92tkBCllvLr6"
},
"source": [
"After the results of the tests (pass or fail) are returned,\n",
"you'll see a report of \"coverage\" from\n",
"[`codecov`](https://about.codecov.io/).\n",
"\n",
"This coverage report tells us which files and how many lines in those files\n",
"were touched by the testing suite."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PllSUe0s5xvU"
},
"source": [
"We do not actually need to provide the names of files with tests in them to `pytest`\n",
"in order for it to run our tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4qOBHJnTZM9x"
},
"source": [
"By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n",
"\n",
"It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n",
"to separate these from the rest of your code\n",
"in a folder or folders named `tests`,\n",
"rather than scattering them around the repo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "acjsYTNSvLr6"
},
"outputs": [],
"source": [
"!ls text_recognizer/tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WZQQZUF0vLr6"
},
"source": [
"Let's take a look at a specific example:\n",
"the tests for some of our utilities around\n",
"custom PyTorch Lightning `Callback`s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oS0xKv1evLr6"
},
"outputs": [],
"source": [
"from text_recognizer.tests import test_callback_utils\n",
"\n",
"\n",
"test_callback_utils.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lko8msn-vLr7"
},
"source": [
"Notice that we can easily import this as a module!\n",
"\n",
"That's another benefit of organizing tests into specialized files."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5A85FUNv75Fr"
},
"source": [
"The particular utility we're testing\n",
"here is designed to prevent crashes:\n",
"it checks for a particular type of error and turns it into a warning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jl4-DiVe76sw"
},
"outputs": [],
"source": [
"from text_recognizer.callbacks.util import check_and_warn\n",
"\n",
"check_and_warn??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B6E0MhduvLr7"
},
"source": [
"Error-handling code is a common cause of bugs,\n",
"a fact discovered\n",
"[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n",
"so it's very important to test it well!\n",
"\n",
"We start with a very basic test,\n",
"which does not touch anything\n",
"outside of the Python standard library,\n",
"even though this tool is intended to be used\n",
"with more complex features of third-party libraries,\n",
"like `wandb` and `tensorboard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xx5koQmJvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_simple??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MZe9-JVjvLr7"
},
"source": [
"Here, we are just testing the core logic.\n",
"This test won't catch many bugs,\n",
"but when it does fail, something has gone seriously wrong.\n",
"\n",
"These kinds of tests are important for resolving a bug:\n",
"we learn nearly as much from the tests that passed\n",
"as we did from the tests that failed.\n",
"If this test has failed, possibly along with others,\n",
"we can rule out an issue in one of the large external codebases\n",
"touched in the other tests, saving us lots of time in our troubleshooting.\n",
"\n",
"The reasoning for the test is explained in the docstrings, \n",
"which are close to the code.\n",
"\n",
"Your test suite should be as welcoming\n",
"as the rest of your codebase!\n",
"The people reading it, for example yourself in six months, \n",
"are likely upset and in need of some kindness.\n",
"\n",
"More practically, we want to keep our time to resolve errors as short as possible,\n",
"and five minutes to write a good docstring now\n",
"can save five minutes during an outage, when minutes really matter."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Om9k-uXhvLr7"
},
"source": [
"That basic test is a start, but it's not enough by itself.\n",
"There's a specific error case that triggered the addition of this code.\n",
"\n",
"So we test that it's handled as expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fjbsb5FvvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_tblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGAIZTUjvLr7"
},
"source": [
"That test can fail if the libraries change around our code,\n",
"i.e. if the `TensorBoardLogger` gets a `log_table` method.\n",
"\n",
"We want to be careful when making assumptions\n",
"about other people's software,\n",
"especially for fast-moving libraries like Lightning.\n",
"If we test that those assumptions hold willy-nilly,\n",
"we'll end up with tests that fail because of\n",
"harmless changes in our dependencies.\n",
"\n",
"Tests that require a ton of maintenance and updating\n",
"without leading to code improvements soak up\n",
"more engineering time than they save\n",
"and cause distrust in the testing suite.\n",
"\n",
"We include this test because `TensorBoardLogger` getting\n",
"a `log_table` method will _also_ change the behavior of our code\n",
"in a breaking way, and we want to catch that before it breaks\n",
"a model training job."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jsy95KAvvLr7"
},
"source": [
"Adding error handling can also accidentally kill the \"happy path\"\n",
"by raising an error incorrectly.\n",
"\n",
"So we explicitly test the _absence of an error_,\n",
"not just its presence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LRlIOkjmvLr8"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_wandblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osiqpLynvLr8"
},
"source": [
"There are more tests we could build, e.g. manipulating classes and testing the behavior,\n",
"testing more classes that might be targeted by `check_and_warn`, or\n",
"asserting that warnings are raised to the command line.\n",
"\n",
"But these three basic tests are likely to catch most changes that would break our code here,\n",
"and they're a lot easier to write than the others.\n",
"\n",
"If this utility starts to get more usage and become a critical path for lots of features, we can always add more!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dm285JE5vLr8"
},
"source": [
"## Interleaving testing and documentation with `doctests`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHWQvgA8vLr8"
},
"source": [
"One function of tests is to build user/reader confidence in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wrhiJBXFvLr8"
},
"source": [
"One function of documentation is to build user/reader knowledge in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1vu12LDhvLr8"
},
"source": [
"These functions are related. Let's put them together:\n",
"put code in a docstring and test that code.\n",
"\n",
"This feature is part of the\n",
"Python standard library via the\n",
"[`doctest` module](https://docs.python.org/3/library/doctest.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rmfIOwXd-Qt7"
},
"source": [
"Here's an example from our `torch` utilities.\n",
"\n",
"The `first_appearance` function can be used to\n",
"e.g. quickly look for stop tokens,\n",
"giving the length of each sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZzURGcD9vLr8"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.util import first_appearance\n",
"\n",
"\n",
"first_appearance??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0VtYcJ1WvLr8"
},
"source": [
"Notice that in the \"Examples\" section,\n",
"there's a short block of code formatted as a\n",
"Python interpreter session,\n",
"complete with outputs.\n",
"\n",
"We can copy and paste that code and\n",
"check that we get the right outputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dj4lNOxJvLr9"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y9AWHFoIvLr9"
},
"source": [
"We can run the test with `pytest` by passing a command line argument,\n",
"`--doctest-modules`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JMaAxv5ovLr9"
},
"outputs": [],
"source": [
"!pytest --doctest-modules text_recognizer/lit_models/util.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6-2_aOUfvLr9"
},
"source": [
"With the\n",
"[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n",
"running `doctest`s happens automatically\n",
"when `pytest` is invoked."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "my_keokPvLr9"
},
"source": [
"## Basic tests for data code"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qj3Bq_j2_A8o"
},
"source": [
"ML code can be hard to test\n",
"since it involves very heavy artifacts, like models and data,\n",
"and very expensive jobs, like training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DT5OmgrQvLr9"
},
"source": [
"For testing our data-handling code in the FSDL codebase,\n",
"we mostly just use `assert`s,\n",
"which throw errors when behavior differs from expectation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bdzn5g4TvLr9"
},
"outputs": [],
"source": [
"!grep \"assert\" -r text_recognizer/data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2aTlfu4_vLr-"
},
"source": [
"This isn't great practice,\n",
"especially as a codebase grows,\n",
"because we can't easily know when these are executed\n",
"or incorporate them into\n",
"testing automation and coverage analysis tools."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IaMTdmbZ_mkW"
},
"source": [
"So it's preferable to collect up these assertions of simple data properties\n",
"into tests that are run like our other tests.\n",
"\n",
"The test below checks whether any data is leaking\n",
"between training, validation, and testing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qx7cxiDdvLr-"
},
"outputs": [],
"source": [
"from text_recognizer.tests.test_iam import test_iam_data_splits\n",
"\n",
"\n",
"test_iam_data_splits??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16TJwhd1vLr-"
},
"source": [
"Notice that we were able to load the test into the notebook\n",
"because it is in a module,\n",
"and so we can run it here as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mArITFkYvLr-"
},
"outputs": [],
"source": [
"test_iam_data_splits()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E4F2uaclvLr-"
},
"source": [
"But we're checking something pretty simple here,\n",
"so the new code in each test is just a single line.\n",
"\n",
"What if we wanted to test more complex properties,\n",
"like comparing rows or calculating statistics?\n",
"\n",
"We'll end up writing more complex code that might itself have subtle bugs,\n",
"requiring tests for our tests and suffering from\n",
"\"tester's regress\".\n",
"\n",
"This is the phenomenon,\n",
"named by analogy with\n",
"[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n",
"in the sociology of science,\n",
"where the validity of our tests is itself\n",
"up for dispute only resolvable by testing the tests,\n",
"but those tests are themselves possibly invalid."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nUGT06gdvLr-"
},
"source": [
"We cut this Gordian knot by using\n",
"a library or framework that is well-tested.\n",
"\n",
"We recommend checking out\n",
"[`great_expectations`](https://docs.greatexpectations.io/docs/)\n",
"if you're looking for a high-quality data testing tool."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dQ5vNsq3vLr-"
},
"source": [
"Especially with data, some tests are particularly \"heavy\" --\n",
"they take a long time,\n",
"and we might want to run them\n",
"on different machines\n",
"and on a different schedule\n",
"than our other tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xephcb0LvLr-"
},
"source": [
"For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n",
"\n",
"We can't just use a cached version of the data,\n",
"since that won't actually execute the code!\n",
"\n",
"This test will take\n",
"as long to run\n",
"and consume as many resources as\n",
"a full download of the data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YSN4w2EqvLr-"
},
"source": [
"`pytest` allows the separation of tests\n",
"into suites with `mark`s,\n",
"which \"tag\" tests with names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V0rScrcXvLr_",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest --markers | head -n 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lr5Ca7B0vLr_"
},
"source": [
"We can choose to run tests with a given mark\n",
"or to skip tests with a given mark, \n",
"among other basic logical operations around combining and filtering marks,\n",
"with `-m`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xmw-Eb1ZvLr_"
},
"outputs": [],
"source": [
"!wandb login # one test requires wandb authentication\n",
"\n",
"!pytest -m \"not data and not slow\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5LuERxOXX_UJ"
},
"source": [
"## Testing training with memorization tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AnWLN4lRvLsA"
},
"source": [
"Training is the process by which we convert inert data into executable models,\n",
"so it is dependent on both.\n",
"\n",
"We decouple checking whether the script has a critical bug\n",
"from whether the data or model code is broken\n",
"by testing on some basic \"fake data\",\n",
"based on a utility from `torchvision`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k4NIc3uWvLsA"
},
"outputs": [],
"source": [
"from text_recognizer.data import FakeImageData\n",
"\n",
"\n",
"FakeImageData.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deN0swwlvLsA"
},
"source": [
"We then test on the actual data with a smaller version of the real model.\n",
"\n",
"We use the Lightning `--fast_dev_run` feature,\n",
"which sets the number of training, validation, and test batches to `1`.\n",
"\n",
"We use a smaller version so that this test can run in just a few minutes\n",
"on a CPU without acceleration.\n",
"\n",
"That allows us to run our tests in environments without GPUs,\n",
"which saves on costs for executing tests.\n",
"\n",
"Here's the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z4J0_uD9vLsA"
},
"outputs": [],
"source": [
"!cat training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-7u9zS1vLsA",
"scrolled": false
},
"outputs": [],
"source": [
"! ./training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UTzfo11KClV3"
},
"source": [
"The above tests don't actually check\n",
"whether any learning occurs,\n",
"they just check\n",
"whether training runs mechanically,\n",
"without any errors.\n",
"\n",
"We also need a\n",
"[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n",
"for learning.\n",
"For that we recommend checking whether\n",
"the model can learn the right\n",
"outputs for a single batch --\n",
"to \"memorize\" the outputs for\n",
"a particular input.\n",
"\n",
"This memorization test won't\n",
"catch every bug or issue in training,\n",
"which is notoriously difficult,\n",
"but it will flag\n",
"some of the most serious issues."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DVSp3aAvLsA"
},
"source": [
"The script below runs a memorization test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2DFVVrxpvLsA"
},
"source": [
"It takes up to two arguments:\n",
"a `MAX`imum number of `EPOCHS` to run for and\n",
"a `CRITERION` value of the loss to test against.\n",
"\n",
"The test passes if the loss is lower than the `CRITERION` value\n",
"after the `MAX`imum number of `EPOCHS` has passed."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oEhJH0e5vLsB"
},
"source": [
"The important line in this script is the one that invokes our training script,\n",
"`training/run_experiment.py`.\n",
"\n",
"The arguments to `run_experiment` have been tuned for maximum possible speed:\n",
"turning off regularization, shrinking the model,\n",
"and skipping parts of Lightning that we don't want to test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T-fFs1xEvLsB"
},
"outputs": [],
"source": [
"!cat training/tests/test_memorize_iam.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X-47tUA_YNGe"
},
"source": [
"If you'd like to see what a memorization run looks like,\n",
"flip the `running_memorization` flag to `True`\n",
"and watch the results stream in to W&B.\n",
"\n",
"The cell should run in about ten minutes on a commodity GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GwTEsZwKvLsB"
},
"outputs": [],
"source": [
"%%time\n",
"running_memorization = False\n",
"\n",
"if running_memorization:\n",
" max_epochs = 1000\n",
" loss_criterion = 0.05\n",
" !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Troubleshooting model speed with the PyTorch Profiler"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpbN-Om2Drf-"
},
"source": [
"Testing code is only half the story here:\n",
"we also need to fix the issues that our tests flag.\n",
"This is the process of troubleshooting.\n",
"\n",
"In this lab,\n",
"we'll focus on troubleshooting model performance issues:\n",
"what to do when your model runs too slowly."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NZzwELPXvLsD"
},
"source": [
"Troubleshooting deep neural networks for speed is challenging.\n",
"\n",
"There are at least three different common approaches,\n",
"each with an increasing level of skill required:\n",
"\n",
"1. Follow best practices advice from others\n",
"([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n",
"[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n",
"2. Take code that runs slowly and use empirical observations to iteratively improve it.\n",
"3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n",
"\n",
"For the full stack deep learning engineer,\n",
"the final level is typically out of reach,\n",
"unless you're specializing in the model performance\n",
"part of the stack in particular.\n",
"\n",
"So we recommend reaching the middle level,\n",
"and this segment of the lab walks through the\n",
"tools that make this easier."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3_yp87UrFZ8M"
},
"source": [
"Because neural network training involves GPU acceleration,\n",
"generic Python profiling tools like\n",
"[`py-spy`](https://github.com/benfred/py-spy)\n",
"won't work, and\n",
"we'll need tools specialized for tracing and profiling DNN training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yspsYVFGEyZm"
},
"source": [
"In general, these tools are for observing what happens while your code is executing:\n",
"_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n",
"\n",
"Because they help us observe the execution in detail,\n",
"they will also help us understand just what is going on during\n",
"a PyTorch training step in greater detail."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YqXq2hKuvLsE"
},
"source": [
"To support profiling and tracing,\n",
"we've added a new argument to `training/run_experiment.py`, `--profile`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z_GMMViWvLsE"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZldoksHPvLsE"
},
"source": [
"As with experiment management, this relies mostly on features of PyTorch Lightning,\n",
"which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n",
"and we just add a few lines of customization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F2iJ0_A6vLsE"
},
"outputs": [],
"source": [
"!cat training/run_experiment.py | grep args.profile -A 5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aw3ppgndvLsE"
},
"source": [
"For more on profiling with Lightning, see the\n",
"[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCAmNW3QEtcD"
},
"source": [
"The cell below runs an epoch of training with tracing and profiling turned on\n",
"and then saves the results locally and to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t4o3ylDgr46F",
"scrolled": false
},
"outputs": [],
"source": [
"import glob\n",
"\n",
"import torch\n",
"import wandb\n",
"\n",
"from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n",
"\n",
"\n",
"# make it easier to separate these from training runs\n",
"%env WANDB_JOB_TYPE=profile\n",
"\n",
"batch_size = 16\n",
"num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n",
"gpus = 1 # must be run with accelerator\n",
"\n",
"%run training/run_experiment.py --wandb --profile \\\n",
" --max_epochs=1 \\\n",
" --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n",
" --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n",
" --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus=1\n",
"\n",
"latest_expt = wandb.run\n",
"\n",
"try: # add execution trace to logged and versioned binaries\n",
" folder = wandb.run.dir\n",
" trace_matcher = wandb.run.dir + \"/*.pt.trace.json\"\n",
" trace_file = glob.glob(trace_matcher)[0]\n",
" trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n",
" trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n",
" wandb.log_artifact(trace_at)\n",
"except IndexError:\n",
" print(\"trace not found\")\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ePTkS3EqO5tN"
},
"source": [
"We get out a table of statistics in the terminal,\n",
"courtesy of Lightning.\n",
"\n",
"Each row lists an operation\n",
"and provides information,\n",
"described in the column headers,\n",
"about the time spent on that operation\n",
"across all the training steps we profiled.\n",
"\n",
"With practice, some useful information can be read out from this table,\n",
"but it's better to start from both a less detailed view,\n",
"in the TensorBoard dashboard,\n",
"and a more detailed view,\n",
"using the Chrome Trace viewer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TzV62f3c7-Bi"
},
"source": [
"## High-level statistics from the PyTorch Profiler in TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mNPKXkYw8NWd"
},
"source": [
"Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CbItwuT88eAV"
},
"outputs": [],
"source": [
"your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n",
"\n",
"print(your_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jE_LooMYHFpF"
},
"source": [
"If at any point you run into issues,\n",
"like the description not matching what you observe,\n",
"check out one of our example runs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "za2zybSwIo5C"
},
"outputs": [],
"source": [
"example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n",
"print(example_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xlrhl1n4HYU6"
},
"source": [
"Once the TensorBoard session has loaded up,\n",
"we are dropped into the Overview\n",
"(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n",
"for an example).\n",
"\n",
"In the top center, we see the **GPU Summary** for our system.\n",
"\n",
"In addition to the name of our GPU,\n",
"there are a few configuration details and top-level statistics.\n",
"They are (tersely) documented\n",
"[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MmBhUDgDLhd1"
},
"source": [
"- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n",
"this is effectively a coarse \"version number\" for your GPU hardware.\n",
"It indexes which features are available,\n",
"with more advanced features being available only at higher compute capabilities.\n",
"It does not directly index the speed or memory of the GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "voUgT6zuLyi0"
},
"source": [
"- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yl-IndtXE4b4"
},
"source": [
"- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n",
"for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n",
"Tensor Cores.\n",
"If you're running on an older GPU without Tensor Cores,\n",
"you should consider upgrading.\n",
"If you're running a more recent GPU but not seeing Tensor Core usage,\n",
"you should switch to half-precision (16-bit) floating point numbers,\n",
"e.g. via mixed-precision training, which Tensor Cores are specialized for."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XxcUf0bBNXy_"
},
"source": [
"- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n",
"at a lower level than just whether something is running at all,\n",
"as in utilization.\n",
"Unlike utilization, reaching 100% is not generally feasible\n",
"and sometimes not desirable.\n",
"Increasing these numbers requires expertise in\n",
"CUDA programming, so we'll target utilization instead."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A88pQn4YMMKc"
},
"source": [
"- **Execution Summary**: This table and pie chart indicates\n",
"how much time within a profiled step\n",
"was spent in each category.\n",
"The value for \"kernel\" execution here\n",
"is equal to the GPU utilization,\n",
"and we want that number to be as close to 100%\n",
"as possible.\n",
"This summary helps us know which\n",
"other operations are taking time,\n",
"like memory being copied between CPU and GPU (`memcpy`)\n",
"or `DataLoader`s executing on the CPU,\n",
"so we can decide where the bottleneck is."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6qjW1RlTQRPv"
},
"source": [
"At the very bottom, you'll find a\n",
"**Performance Recommendation**\n",
"tab that sometimes suggests specific methods for improving performance.\n",
"\n",
"If this tab makes suggestions, you should certainly take them!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pWY5AhrcRQmJ"
},
"source": [
"For more on using the profiler in TensorBoard,\n",
"including some of the other, more detailed views\n",
"available via the \"Views\" dropdown menu, see\n",
"[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQwrPY_H77H8"
},
"source": [
"## Going deeper with the Chrome Trace Viewer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yhwo7fslvLsH"
},
"source": [
"So far, we've seen summary-level information about our training steps\n",
"in the table from Lightning and in the TensorBoard Overview.\n",
"These give aggregate statistics about the computations that occurred,\n",
"but understanding how to interpret those statistics\n",
"and use them to speed up our networks\n",
"requires understanding just what is\n",
"happening in our training step.\n",
"\n",
"Fundamentally,\n",
"all computations are processes that unfold in time.\n",
"\n",
"If we want to really understand our training step,\n",
"we need to display it that way:\n",
"what operations were occurring,\n",
"on both the CPU and GPU,\n",
"at each moment in time during the training step.\n",
"\n",
"This information on timing is collected in the trace.\n",
"One of the best tools for viewing the trace over time\n",
"is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wUkZItxYc20A"
},
"source": [
"Let's tour the trace we just logged\n",
"with an aim to really understanding just\n",
"what is happening when we call\n",
"`training_step`\n",
"and by extension `.forward`, `.backward`, and `optimizer.step`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9w9F2UA7Qctg"
},
"source": [
"The Chrome Trace Viewer is built into W&B,\n",
"so we can view our traces in their interface.\n",
"\n",
"The cell below embeds the trace inside the notebook,\n",
"but you may wish to open it separately,\n",
"with the \"Open page\" button or by navigating to the URL,\n",
"so that you can interact with it\n",
"as you read the description below.\n",
"Display directly on W&B is also a bit less temperamental\n",
"than display on W&B inside a notebook.\n",
"\n",
"Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n",
"so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n",
"It also can interact poorly with browser extensions (e.g. ad blockers),\n",
"so you may need to deactivate them temporarily in order to see it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OMUs4aby6Rfd"
},
"outputs": [],
"source": [
"trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n",
"trace_url = trace_files_url + \"training_step.pt.trace.json\"\n",
"\n",
"example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n",
"\n",
"print(trace_url)\n",
"IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qNVpGeQtQjMG"
},
"source": [
"> **Heads up!** We're about to do a tour of the\n",
"> precise details of the tracing information logged\n",
"> during the execution of the training code.\n",
"> The only way to learn how to troubleshoot model performance\n",
"> empirically is to look at the details,\n",
"> but the details depend on the precise machine being used\n",
"> -- GPU and CPU and RAM.\n",
"> That means even within Colab,\n",
"> these details change from session to session.\n",
"> So if you don't observe a phenomenon or feature\n",
"> described in the tour below, check out\n",
"> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n",
"> on W&B while reading through the next section of the lab,\n",
"> and return to your trace once you understand the trace viewer better at the end.\n",
"> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n",
"> can sometimes be a bit janky."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kXMcBhnCgdN_"
},
"source": [
"This trace reveals, in nanosecond-level detail,\n",
"what's going on inside of a `training_step`\n",
"on both the GPU and the CPU.\n",
"\n",
"Time is on the horizontal axis.\n",
"Colored bars represent method calls,\n",
"and the methods called by a method are placed underneath it vertically,\n",
"a visualization known as an\n",
"[icicle chart](https://www.brendangregg.com/flamegraphs.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "67BsNzDfVIeg"
},
"source": [
"Let's orient ourselves with some gross features:\n",
"the forwards pass,\n",
"GPU kernel execution,\n",
"the backwards pass,\n",
"and the optimizer step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBEFgtRCKqrh"
},
"source": [
"### The forwards pass"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nYhiWesVMjK"
},
"source": [
"Type `resnet` into the search bar in the top-right.\n",
"\n",
"This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n",
"\n",
"It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n",
"\n",
"You can click the arrows next to that tile to partially collapse these blocks.\n",
"\n",
"Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n",
"It should be at roughly the same height.\n",
"\n",
"Clear the search bar so that the trace is in color.\n",
"Zoom in on the area of the forwards pass\n",
"using the \"zoom\" tool in the floating toolbar,\n",
"so you can see more detail.\n",
"The zoom tool is indicated by a two-headed arrow\n",
"pointing into and out of the screen.\n",
"\n",
"Switch to the \"drag\" tool,\n",
"represented by a four-headed arrow.\n",
"Click-and-hold to use this tool to focus\n",
"on different parts of the timeline\n",
"and click on the individual colored boxes\n",
"to see details about a particular method call.\n",
"\n",
"As we go down in the icicle chart,\n",
"we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n",
"to much more precise `cudnn` and `cuda` operations\n",
"(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n",
"\n",
"`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n",
"is the tensor math library in PyTorch\n",
"that links to specific backends like `cudnn`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fq181ybIvLsH"
},
"source": [
"### GPU kernel execution"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbkWp5aKvLsH"
},
"source": [
"Towards the bottom, you should see a section labeled \"GPU\".\n",
"The label appears on the far left.\n",
"\n",
"Within it, you'll see one or more \"`stream`s\".\n",
"These are units of work on a GPU,\n",
"akin loosely to threads on the CPU.\n",
"\n",
"When there are colored bars in this area,\n",
"the GPU is doing work of some kind.\n",
"The fraction of this bar that is filled in with color\n",
"is the same as the \"GPU Utilization %\" we've seen previously.\n",
"So the first thing to visually assess\n",
"in a trace view of PyTorch code\n",
"is what fraction of this area is filled with color.\n",
"\n",
"In CUDA, work is queued up to be\n",
"placed into streams and completed, on the GPU,\n",
"in a distributed and asynchronous manner.\n",
"\n",
"The selection of which work to do\n",
"is happening on the CPU,\n",
"and that's what we were looking at above.\n",
"\n",
"The CPU and the GPU have to work together to coordinate\n",
"this work.\n",
"\n",
"Type `cuda` into the search bar and you'll see these coordination operations happening:\n",
"`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n",
"\n",
"Running the same PyTorch model\n",
"with the same high level operations like `Conv2d` in different versions of PyTorch,\n",
"on different GPUs, and even on tensors of different sizes will result\n",
"in different choices of concrete kernel operation,\n",
"e.g. different matrix multiplication algorithms.\n",
"\n",
"Type `sync` into the search bar and you'll see places where either work on the GPU\n",
"or work on the CPU needs to await synchronization,\n",
"e.g. copying data from the CPU to the GPU\n",
"or the CPU waiting to decide what to do next\n",
"on the basis of the contents of a tensor.\n",
"\n",
"If you see a \"sync\" block above an area\n",
"where the stream on the GPU is empty,\n",
"you've got a performance bottleneck due to synchronization\n",
"between the CPU and GPU.\n",
"\n",
"To resolve the bottleneck,\n",
"head up the icicle chart until you reach the recognizable\n",
"PyTorch modules and operations.\n",
"Find where they are called in your PyTorch module.\n",
"That's a good place to review your code to understand why the synchronization is happening\n",
"and to remove it if it's not necessary."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XeMPbu_jvLsI"
},
"source": [
"### The backwards pass\n",
"\n",
"Type `backward` into the search bar.\n",
"\n",
"This will highlight components of our backwards pass.\n",
"\n",
"If you read it from left to right,\n",
"you'll see that it begins by calculating the loss\n",
"(type `NllLoss2DBackward` into the search bar if you can't find it)\n",
"and ends by doing a `ConvolutionBackward`,\n",
"the first layer of the ResNet.\n",
"It is, indeed, backwards.\n",
"\n",
"Like the forwards pass,\n",
"the backwards pass also involves the CPU\n",
"telling the GPU which kernels to run.\n",
"It's typically run in a separate\n",
"thread from the forwards pass,\n",
"so you'll see it separated out from the forwards pass\n",
"in the trace viewer.\n",
"\n",
"Generally, there's no need to specifically optimize the backwards pass --\n",
"removing bottlenecks in the forwards pass results in a fast backwards pass.\n",
"\n",
"One reason why is that these two passes are just\n",
"\"transposes\" of one another,\n",
"so they share a lot of properties,\n",
"and bottlenecks in one become bottlenecks in the other.\n",
"We can choose to optimize either one of the two.\n",
"But the forwards pass is under our direct control,\n",
"so it's easier for us to reason about.\n",
"\n",
"Another reason is that the forwards pass is more likely to have bottlenecks.\n",
"The forwards pass is a dynamic process,\n",
"with each line of Python adding more to the compute graph.\n",
"Backwards passes, on the other hand, use a static compute graph,\n",
"the one just defined by the forwards pass,\n",
"so more optimizations are possible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gWiDw0vCvLsI"
},
"source": [
"### The optimizer step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ndfkzEdnvLsI"
},
"source": [
"Type `Adam.step` into the search bar to highlight the computations of the optimizer.\n",
"\n",
"As with the two passes,\n",
"we are still using the CPU\n",
"to launch kernels on the GPU.\n",
"But now the CPU is looping,\n",
"in Python, over the parameters\n",
"and applying the Adam update rules to each.\n",
"\n",
"We now know enough to see that\n",
"this is not great for our GPU utilization:\n",
"there are many areas of gray\n",
"in between the colored bars\n",
"in the GPU stream in this area.\n",
"\n",
"In the time it takes CUDA to multiply\n",
"thousands of numbers,\n",
"Python has not yet finished cleaning up\n",
"after its request for that multiplication.\n",
"\n",
"As of writing in August 2022,\n",
"more efficient optimizers are not a stable part of PyTorch (v1.12), but\n",
"[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n",
"and stable implementations outside of PyTorch.\n",
"The standard implementations are\n",
"[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n",
"not to be confused with the\n",
"[Apex Optimizers Project](https://www.apexoptimizers.com/),\n",
"which is a collection of fitness-themed cheetah NFTs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WX0jxeafvLsI"
},
"source": [
"## Take-aways for PyTorch performance bottleneck troubleshooting"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CugD-bK2vLsI"
},
"source": [
"Our goal here was to learn some basic principles and tools for troubleshooting\n",
"the most common issues and the lowest-hanging fruit in PyTorch code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SwHwJkVMHYGA"
},
"source": [
"\n",
"Here's an overview in terms of a \"host\",\n",
"generally the CPU,\n",
"and a \"device\", here the GPU.\n",
"\n",
"- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n",
"- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n",
" - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n",
"- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n",
"stores information about the contents of tensors,\n",
"not just their metadata."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gntx28p9cBP5"
},
"source": [
"Towards that goal, we viewed the trace to get an understanding of\n",
"what's going on inside a PyTorch training step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AKvZGPnkeXvq"
},
"source": [
"Here's what this means in terms of troubleshooting bottlenecks.\n",
"\n",
"We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n",
"before the previous kernel finishes.\n",
"\n",
"Ideally, the CPU is actually getting far _ahead_ of execution\n",
"on the GPU.\n",
"If the CPU makes it all the way through the backwards pass before the GPU is done,\n",
"that's great!\n",
"The GPU(s) are the expensive part,\n",
"and it's easy to use multiprocessing so that\n",
"the CPU has other things to do.\n",
"\n",
"This helps explain at least one common piece of advice:\n",
"the larger our batches are,\n",
"the more work the GPU has to do for the same work done by the CPU,\n",
"and so the better our utilization will be."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XMztpa-TccH4"
},
"source": [
"We operationalize our desire to never be waiting on the CPU with a simple metric:\n",
"**100% GPU utilization**, meaning a kernel is running at all times.\n",
"\n",
"This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n",
"\n",
"You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will cost more than 2x as much."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7kYBygfScR6z"
},
"source": [
"Here are some of the most common issues that lead to low GPU Utilization, and how to resolve them:\n",
"1. **The CPU is too weak**.\n",
"Because so much of the discussion around DNN performance is about GPUs,\n",
"it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n",
"_Resolution_:\n",
"Use nice CPUs, like\n",
"[Threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n",
"2. **Too much Python during the `training_step`**.\n",
"Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on a `__init__`\n",
"that takes longer than running an entire layer.\n",
"_Resolution_:\n",
"Look for low utilization areas of the trace\n",
"and check what's happening on the CPU at that time\n",
"and carefully review the Python code being executed.\n",
"3. **Unnecessary Host/Device synchronization**.\n",
"If one of your operations depends on the values in a tensor,\n",
"like `if xs.mean() >= 0`,\n",
"you'll induce a synchronization between\n",
"the host and the device and possibly lead\n",
"to an expensive and slow copy of data.\n",
"_Resolution_:\n",
"Replace these operations as much as possible\n",
"with purely array-based calculations.\n",
"4. **Bottlenecking on the DataLoader**.\n",
"In addition to coordinating the work on the GPU,\n",
"CPUs often perform heavy data operations,\n",
"including communication over the network\n",
"and writing to/reading from disk.\n",
"These are generally done in parallel to the forwards\n",
"and backwards passes,\n",
"but if they don't finish before the next batch is needed,\n",
"they will become the bottleneck.\n",
"_Resolution_:\n",
"Get better hardware for compute,\n",
"memory, and network.\n",
"For software solutions, the answer \n",
"is a bit more complex and application-dependent.\n",
"For generic tips, see\n",
"[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n",
"in the PyTorch forums.\n",
"For techniques in computer vision, see\n",
"[the FFCV library](https://github.com/libffcv/ffcv)\n",
"and for techniques in NLP, see e.g.\n",
"[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n",
"and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i2WYS8bQvLsJ"
},
"source": [
"### Further steps in making DNNs go brrrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T0wW2_lRKfY1"
},
"source": [
"It's important to note that utilization\n",
"is just an easily measured metric\n",
"that can reveal common bottlenecks.\n",
"Having high utilization does not automatically mean\n",
"that your performance is fully optimized.\n",
"\n",
"For example,\n",
"synchronization events between GPUs\n",
"are counted as kernels,\n",
"so a deadlock during distributed training\n",
"can show up as 100% utilization,\n",
"despite literally no useful work occurring.\n",
"\n",
"Just switching to \n",
"double precision floats, `--precision=64`,\n",
"will generally lead to much higher utilization.\n",
"The GPU operations take longer\n",
"for roughly the same amount of CPU effort,\n",
"but the added precision brings no benefit.\n",
"\n",
"In particular, it doesn't make for models\n",
"that perform better on our correctness metrics,\n",
"like loss and accuracy.\n",
"\n",
"Another useful yardstick to add\n",
"to utilization is examples per second,\n",
"which incorporates how quickly the model is processing data examples\n",
"and calculating gradients.\n",
"\n",
"But really,\n",
"the gold star is _decrease in loss per second_.\n",
"This metric connects model design choices\n",
"and hyperparameters with purely engineering concerns,\n",
"so it disrespects abstraction barriers\n",
"and doesn't generally lead to actionable recommendations,\n",
"but it is, in the end, the real goal:\n",
"make the loss go down faster so we get better models sooner."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EFzPsplfdo_o"
},
"source": [
"For PyTorch internals abstractly,\n",
"see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"For more on performance considerations in PyTorch,\n",
"see [Horace He's blog post](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yq6-S6TC38AY"
},
"source": [
"### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n",
"\n",
"One of the most important features for making\n",
"PyTorch run quickly is the\n",
"`MultiprocessingDataLoader`,\n",
"which executes batching of data in a separate process\n",
"from the forwards and backwards passes.\n",
"\n",
"By default in PyTorch,\n",
"this feature is actually turned off,\n",
"via the `DataLoader` argument `num_workers`\n",
"having a default value of `0`,\n",
"but we set the `DEFAULT_NUM_WORKERS`\n",
"to a value based on the number of CPUs\n",
"available on the system running the code.\n",
"\n",
"Re-run the profiling cell,\n",
"but set `num_workers` to `0`\n",
"to turn off multiprocessing.\n",
"\n",
"Compare and contrast the two traces,\n",
"both for total runtime\n",
"(see the time axis at the top of the trace)\n",
"and for utilization.\n",
"\n",
"If you're unable to run the profiles,\n",
"see the results\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n",
"which juxtaposes two traces,\n",
"with in-process dataloading on the left and\n",
"multiprocessing dataloading on the right."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T2i_a5eVeIoA"
},
"source": [
"The file below incorrectly implements and then incorrectly tests\n",
"a simple PyTorch utility for adding five to every entry of a tensor\n",
"and then calculating the sum.\n",
"\n",
"Even worse, it does it with horrible style!\n",
"\n",
"The cells below apply our linting checks\n",
"(after automatically fixing the formatting)\n",
"and run the test.\n",
"\n",
"Fix all of the lints,\n",
"implement the function correctly,\n",
"and then implement some basic tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wSon2fB5VVM_"
},
"source": [
"- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aYiRvU4HA84t"
},
"outputs": [],
"source": [
"%%writefile training/fixme.py\n",
"import torch\n",
"from training import run_experiment\n",
"from numpy import *\n",
"import random\n",
"from pathlib import Path\n",
"\n",
"\n",
"\n",
"\n",
"def add_five_and_sum(tensor):\n",
" # this function is not implemented right,\n",
" # but it's supposed to add five to all tensor entries and sum them up\n",
" return 1\n",
"\n",
"def test_add_five_and_sum():\n",
" # and this test isn't right either! plus this isn't exactly a docstring\n",
" all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n",
" all_fives = 5 * all_ones\n",
" assert False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EXJpmvuzT1w0"
},
"outputs": [],
"source": [
"!pre-commit run black --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SRO-oJfdUrcQ"
},
"outputs": [],
"source": [
"!cat training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jM8NHxVbSEQD"
},
"outputs": [],
"source": [
"!pre-commit run --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kj0VMBSndtkc"
},
"outputs": [],
"source": [
"!pytest training/fixme.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab05_troubleshooting.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab07/notebooks/lab06_data.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 06: Data Annotation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How the `IAM` handwriting dataset is structured on disk and how it is processed into an ML-friendly format\n",
"- How to setup a [Label Studio](https://labelstud.io/) data annotation server\n",
"- Just how messy data really is"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 6\n",
"\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpvaHz9TEGwV"
},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gsXpeXi2EGwV"
},
"outputs": [],
"source": [
"from IPython.display import IFrame\n",
"\n",
"\n",
"IFrame(src=\"https://fsdl.me/2022-lab-06-video-embed\", width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XTkKzEMNR8XZ"
},
"source": [
"# `IAMParagraphs`: From annotated data to a PyTorch `Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3mQLbjuiwZuj"
},
"source": [
"We've used the `text_recognizer.data` submodule\n",
"and its `LightningDataModule`s -- `IAMLines` and `IAMParagraphs`\n",
"for lines and paragraphs of handwritten text\n",
"from the\n",
"[IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database).\n",
"\n",
"These classes convert data from a database-friendly format\n",
"designed for storage and transfer into the\n",
"format our DNNs expect:\n",
"PyTorch `Tensor`s.\n",
"\n",
"In this section,\n",
"we'll walk through that process in detail.\n",
"\n",
"In the following section,\n",
"we'll see how data\n",
"goes from signals measured in the world\n",
"to the format we consume here."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "499c23a6"
},
"source": [
"## Dataset structure on disk"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a3438d2e"
},
"source": [
"We begin by downloading the raw data to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "18900eec"
},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM\n",
"\n",
"iam = IAM()\n",
"iam.prepare_data()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a332f359"
},
"source": [
"The `IAM` dataset is downloaded as a zip file\n",
"and then unzipped:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d6c44266"
},
"outputs": [],
"source": [
"from text_recognizer.metadata.iam import DL_DATA_DIRNAME\n",
"\n",
"\n",
"iam_dir = DL_DATA_DIRNAME\n",
"!ls {iam_dir}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8463c2d1"
},
"source": [
"The unzipped dataset is not simply a flat directory of files.\n",
"\n",
"Instead, there are a number of subfolders,\n",
"each of which contains a particular type of data or metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "536924f7"
},
"outputs": [],
"source": [
"iamdb = iam_dir / \"iamdb\"\n",
"\n",
"!du -h {iamdb}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b745a594"
},
"source": [
"For example, the `task` folder contains metadata about canonical dataset splits:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "84c21f75"
},
"outputs": [],
"source": [
"!find {iamdb / \"task\"} | grep \"\\\\.txt$\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mEb0Pdm4vIHe"
},
"source": [
"We find the images of handwritten text in the `forms` folder.\n",
"\n",
"An individual \"datapoint\" in `IAM` is a \"form\",\n",
"because the humans whose hands wrote the text were prompted to write on \"forms\",\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "945d5e3a"
},
"outputs": [],
"source": [
"from IPython.display import Image\n",
"\n",
"\n",
"form_fn, = !find {iamdb}/forms | grep \".jpg$\" | sort | head -n 1\n",
"\n",
"print(form_fn)\n",
"Image(filename=form_fn, width=\"360\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b9e9e384"
},
"source": [
"Meanwhile, the `xml` files contain the data annotations,\n",
"written out as structured text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6add5c5a"
},
"outputs": [],
"source": [
"xml_fn, = !find {iamdb}/xml | grep \"\\.xml$\" | sort | head -n 1\n",
"\n",
"!cat {xml_fn} | grep -A 100 \"handwritten-part\" | grep \"
", "", " and ", *tokens, " and ", *tokens, ""]
self.end_index = self.inverse_mapping["",
""]
self.end_token = inverse_mapping[""]
self.end_token = inverse_mapping[""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 01: Deep Neural Networks in PyTorch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How to write a basic neural network from scratch in PyTorch\n",
"- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6c7bFQ20LbLB"
},
"source": [
"At its core, PyTorch is a library for\n",
"- doing math on arrays\n",
"- with automatic calculation of gradients\n",
"- that is easy to accelerate with GPUs and distribute over nodes.\n",
"\n",
"Much of the time,\n",
"we work at a remove from the core features of PyTorch,\n",
"using abstractions from `torch.nn`\n",
"or from frameworks on top of PyTorch.\n",
"\n",
"This tutorial builds those abstractions up\n",
"from core PyTorch,\n",
"showing how to go from basic iterated\n",
"gradient computation and application\n",
"to a solid training and validation loop.\n",
"It is adapted from the PyTorch tutorial\n",
"[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n",
"\n",
"We assume familiarity with the fundamentals of ML and DNNs here,\n",
"like gradient-based optimization and statistical learning.\n",
"For refreshing on those, we recommend\n",
"[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n",
"or\n",
"[the NYU course on deep learning by LeCun and Canziani](https://cds.nyu.edu/deep-learning/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 1\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6wJ8r7BTPB-t"
},
"source": [
"# Getting data and making `Tensor`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MpRyqPPYie-F"
},
"source": [
"Before we can build a model,\n",
"we need data.\n",
"\n",
"The code below uses the Python standard library to download the\n",
"[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n",
"from the internet.\n",
"\n",
"The data used to train state-of-the-art models these days\n",
"is generally too large to be stored on the disk of any single machine\n",
"(to say nothing of the RAM!),\n",
"so fetching data over a network is a common first step in model training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsokTZTMJ3x6"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import requests\n",
"\n",
"\n",
"def download_mnist(path):\n",
" url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
"\n",
" if not (path / filename).exists():\n",
" content = requests.get(url + filename).content\n",
" (path / filename).open(\"wb\").write(content)\n",
"\n",
" return path / filename\n",
"\n",
"\n",
"data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n",
"path = data_path / \"downloaded\" / \"vector-mnist\"\n",
"path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"datafile = download_mnist(path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-S0es1DujOyr"
},
"source": [
"Larger data consumes more resources --\n",
"when reading, writing, and sending over the network --\n",
"so the dataset is compressed\n",
"(`.gz` extension).\n",
"\n",
"Each piece of the dataset\n",
"(training and validation inputs and outputs)\n",
"is a single Python object\n",
"(specifically, an array).\n",
"We can persist Python objects to disk\n",
"(also known as \"serialization\")\n",
"and load them back in\n",
"(also known as \"deserialization\")\n",
"using the `pickle` library\n",
"(`.pkl` extension)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QZosCF1xJ3x7"
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"\n",
"\n",
"def read_mnist(path):\n",
" with gzip.open(path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
" return x_train, y_train, x_valid, y_valid\n",
"\n",
"x_train, y_train, x_valid, y_valid = read_mnist(datafile)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIYUbKgmknDf"
},
"source": [
"PyTorch provides its own array type,\n",
"the `torch.Tensor`.\n",
"The cell below converts our arrays into `torch.Tensor`s.\n",
"\n",
"Very roughly speaking, a \"tensor\" in ML\n",
"just means the same thing as an\n",
"\"array\" elsewhere in computer science.\n",
"Terminology is different in\n",
"[physics](https://physics.stackexchange.com/a/270445),\n",
"[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n",
"and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n",
"but here the term \"tensor\" is intended to connote\n",
"an array that might have more than two dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ea5d3Ggfkhea"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D0AMKLxGkmc_"
},
"source": [
"Tensors are defined by their contents:\n",
"they are big rectangular blocks of numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yPvh8c_pkl5A"
},
"outputs": [],
"source": [
"print(x_train, y_train, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4UOYvwjFqdzu"
},
"source": [
"Accessing the contents of `Tensor`s is called \"indexing\",\n",
"and uses the same syntax as general Python indexing.\n",
"It always returns a new `Tensor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9zGDAPXVqdCm"
},
"outputs": [],
"source": [
"y_train[0], x_train[0, ::2]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QhJcOr8TmgmQ"
},
"source": [
"PyTorch, like many libraries for high-performance array math,\n",
"allows us to quickly and easily access metadata about our tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4ENirftAnIVM"
},
"source": [
"The most important pieces of metadata about a `Tensor`,\n",
"or any array, are its _dimension_\n",
"and its _shape_.\n",
"\n",
"The dimension specifies how many indices you need to get a number\n",
"out of an array."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mhaN6qW0nA5t"
},
"outputs": [],
"source": [
"x_train.ndim, y_train.ndim"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pYEk13yoGgz"
},
"outputs": [],
"source": [
"x_train[0, 0], y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rv2WWNcHkEeS"
},
"source": [
"For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n",
"For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yZ6j-IGPJ3x7"
},
"outputs": [],
"source": [
"n, c = x_train.shape\n",
"print(x_train.shape)\n",
"print(y_train.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H-HFN9WJo6FK"
},
"source": [
"This metadata serves a similar purpose for `Tensor`s\n",
"as type metadata serves for other objects in Python\n",
"(and other programming languages).\n",
"\n",
"That is, types tell us whether an object is an acceptable\n",
"input for or output of a function.\n",
"Many functions on `Tensor`s, like indexing and\n",
"matrix multiplication,\n",
"can only accept as input `Tensor`s of a certain shape and dimension\n",
"and will return as output `Tensor`s of a certain shape and dimension.\n",
"\n",
"So printing `ndim` and `shape` to track\n",
"what's happening to `Tensor`s during a computation\n",
"is an important piece of the debugging toolkit!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wCjuWKKNrWGM"
},
"source": [
"We won't spend much time here on writing raw array math code in PyTorch,\n",
"nor will we spend much time on how PyTorch works.\n",
"\n",
"> If you'd like to get better at writing PyTorch code,\n",
"try out\n",
"[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n",
"We wrote a bit about what these puzzles reveal about programming\n",
"with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n",
"\n",
"> If you'd like to get a better understanding of the internals\n",
"of PyTorch, check out\n",
"[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"As we'll see below,\n",
"`torch.nn` provides most of what we need\n",
"for building deep learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Li5e_jiJpLSI"
},
"source": [
"The `Tensor`s inside of the `x_train` `Tensor`\n",
"aren't just any old blocks of numbers:\n",
"they're images of handwritten digits.\n",
"The `y_train` `Tensor` contains the identities of those digits.\n",
"\n",
"Let's take a look at a random example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4VsHk6xNJ3x8"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"import random\n",
"\n",
"import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n",
"\n",
"import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n",
"\n",
"idx = random.randrange(len(x_train))  # upper bound is exclusive, so idx is always a valid index\n",
"example = x_train[idx]\n",
"\n",
"print(y_train[idx]) # the label of the image\n",
"wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PC3pwoJ9s-ts"
},
"source": [
"We want to build a deep network that can take in an image\n",
"and return the number that's in the image.\n",
"\n",
"We'll build that network\n",
"by fitting it to `x_train` and `y_train`.\n",
"\n",
"We'll first do our fitting with just basic `torch` components and Python,\n",
"then we'll add in other `torch` gadgets and goodies\n",
"until we have a more realistic neural network fitting loop.\n",
"\n",
"Later in the labs,\n",
"we'll see how to even more quickly build\n",
"performant, robust fitting loops\n",
"that have even more features\n",
"by using libraries built on top of PyTorch."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DTLdqCIGJ3x6"
},
"source": [
"# Building a DNN using only `torch.Tensor` methods and Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8D8Xuh2xui3o"
},
"source": [
"One of the really great features of PyTorch\n",
"is that writing code in PyTorch feels\n",
"very similar to writing other code in Python --\n",
"unlike other deep learning frameworks\n",
"that can sometimes feel like their own language\n",
"or programming paradigm.\n",
"\n",
"This fact can sometimes be obscured\n",
"when you're using lots of library code,\n",
"so we start off by just using `Tensor`s and the Python standard library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tOV0bxySJ3x9"
},
"source": [
"## Defining the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLH_zUWkw3W0"
},
"source": [
"We'll make the simplest possible neural network:\n",
"a single layer that performs matrix multiplication,\n",
"and adds a vector of biases.\n",
"\n",
"We'll need values for the entries of the matrix,\n",
"which we generate randomly.\n",
"\n",
"We also need to tell PyTorch that we'll\n",
"be taking gradients with respect to\n",
"these `Tensor`s later, so we use `requires_grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1c21c8XQJ3x-"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"weights = torch.randn(784, 10) / math.sqrt(784)\n",
"weights.requires_grad_()\n",
"bias = torch.zeros(10, requires_grad=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZC8A01sytm2"
},
"source": [
"We can combine our beloved Python operators,\n",
"like `+` and `*` and `@` and indexing,\n",
"to define the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Eoymwooyq0-"
},
"outputs": [],
"source": [
"def linear(x: torch.Tensor) -> torch.Tensor:\n",
" return x @ weights + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5tIRHR_HxeZf"
},
"source": [
"We need to normalize our model's outputs with a `softmax`\n",
"to get our model to output something we can use\n",
"as a probability distribution --\n",
"the probability that the network assigns to each label for the image.\n",
"\n",
"For that, we'll need some `torch` math functions,\n",
"like `torch.sum` and `torch.exp`.\n",
"\n",
"We compute the logarithm of that softmax value\n",
"in part for numerical stability reasons\n",
"and in part because\n",
"[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WuZRGSr4J3x-"
},
"outputs": [],
"source": [
"def log_softmax(x: torch.Tensor) -> torch.Tensor:\n",
" return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n",
"\n",
"def model(xb: torch.Tensor) -> torch.Tensor:\n",
" return log_softmax(linear(xb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-pBI4pOM011q"
},
"source": [
"Typically, we split our dataset up into smaller \"batches\" of data\n",
"and apply our model to one batch at a time.\n",
"\n",
"Since our dataset is just a `Tensor`,\n",
"we can pull that off just with indexing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXsHak23J3x_"
},
"outputs": [],
"source": [
"bs = 64 # batch size\n",
"\n",
"xb = x_train[0:bs] # a batch of inputs\n",
"outs = model(xb) # outputs on that batch\n",
"\n",
"print(outs[0], outs.shape) # outputs on the first element of the batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VPrG9x1DJ3x_"
},
"source": [
"## Defining the loss and metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zEwPJmgZ1HIp"
},
"source": [
"Our model produces outputs, but they are mostly wrong,\n",
"since we set the weights randomly.\n",
"\n",
"How can we quantify just how wrong our model is,\n",
"so that we can make it better?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JY-2QZEu1Xc7"
},
"source": [
"We want to compare the outputs and the target labels,\n",
"but the model outputs a probability distribution,\n",
"and the labels are just numbers.\n",
"\n",
"We can take the label that had the highest probability\n",
"(the index of the largest output for each input,\n",
"aka the `argmax` over `dim`ension `1`)\n",
"and treat that as the model's prediction\n",
"for the digit in the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_sHmDw_cJ3yC"
},
"outputs": [],
"source": [
"def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n",
" preds = torch.argmax(out, dim=1)\n",
" return (preds == yb).float().mean()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PfrDJb2EF_uz"
},
"source": [
"If we run that function on our model's `out`put`s`,\n",
"we can confirm that the random model isn't doing well --\n",
"we expect roughly one in ten predictions to be correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8l3aRMNaJ3yD"
},
"outputs": [],
"source": [
"yb = y_train[0:bs]\n",
"\n",
"acc = accuracy(outs, yb)\n",
"\n",
"print(acc)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxRfO1HQ3VYs"
},
"source": [
"We can calculate how well our network is doing,\n",
"so are we ready to use optimization to make it do better?\n",
"\n",
"Not yet!\n",
"To train neural networks, we use gradients\n",
"(aka derivatives).\n",
"So all of the functions we use need to be differentiable --\n",
"in particular they need to change smoothly so that a small change in input\n",
"can only cause a small change in output.\n",
"\n",
"Our `argmax` breaks that rule\n",
"(if the values at index `0` and index `N` are really close together,\n",
"a tiny change can change the output by `N`)\n",
"so we can't use it.\n",
"\n",
"If we try to run our `backward` pass to get a gradient,\n",
"we get a `RuntimeError`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g5AnK4md4kxv"
},
"outputs": [],
"source": [
"try:\n",
" acc.backward()\n",
"except RuntimeError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HJ4WWHHJ460I"
},
"source": [
"So we'll need something else:\n",
"a differentiable function that gets smaller when\n",
"our model gets better, aka a `loss`.\n",
"\n",
"The typical choice is to maximize the\n",
"probability the network assigns to the correct label.\n",
"\n",
"We could try doing that directly,\n",
"but more generally,\n",
"we want the model's output probability distribution\n",
"to match what we provide it -- \n",
"here, we claim we're 100% certain in every label,\n",
"but in general we allow for uncertainty.\n",
"We quantify that match with the\n",
"[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n",
"\n",
"Cross entropies\n",
"[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n",
"including more familiar functions like the\n",
"mean squared error and the mean absolute error.\n",
"\n",
"We can calculate it directly from the outputs and target labels\n",
"using some cute tricks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-k20rW_rJ3yA"
},
"outputs": [],
"source": [
"def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
" return -output[range(target.shape[0]), target].mean()\n",
"\n",
"loss_func = cross_entropy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZa1DSGN7zPK"
},
"source": [
"With random guessing on a dataset with 10 equally likely options,\n",
"we expect our loss value to be close to the negative logarithm of 1/10:\n",
"the amount of entropy in a uniformly random digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1bKRJ90MJ3yB"
},
"outputs": [],
"source": [
"print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hTgFTdVgAGJW"
},
"source": [
"Now we can call `.backward` without PyTorch complaining:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1LH_ZpY0_e_6"
},
"outputs": [],
"source": [
"loss = loss_func(outs, yb)\n",
"\n",
"loss.backward()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji0FA3dDACUk"
},
"source": [
"But wait, where are the gradients?\n",
"They weren't returned by `loss` above,\n",
"so where could they be?\n",
"\n",
"They've been stored in the `.grad` attribute\n",
"of the parameters of our model,\n",
"`weights` and `bias`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zgtyyhp__s8a"
},
"outputs": [],
"source": [
"bias.grad"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dWTYno0JJ3yD"
},
"source": [
"## Defining and running the fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TTR2Qo9F8ZLQ"
},
"source": [
"We now have all the ingredients we need to fit a neural network to data:\n",
"- data (`x_train`, `y_train`)\n",
"- a network architecture with parameters (`model`, `weights`, and `bias`)\n",
"- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n",
"\n",
"We can put them together into a training loop\n",
"just using normal Python features,\n",
"like `for` loops, indexing, and function calls:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SzNZVEiVJ3yE"
},
"outputs": [],
"source": [
"lr = 0.5 # learning rate hyperparameter\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs): # loop over the data repeatedly\n",
" for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n",
" start_idx = ii * bs # we are ii batches in, each of size bs\n",
" end_idx = start_idx + bs # and we want the next bs entires\n",
"\n",
" # pull batches from x and from y\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
"\n",
" # run model\n",
" pred = model(xb)\n",
"\n",
" # get loss\n",
" loss = loss_func(pred, yb)\n",
"\n",
" # calculate the gradients with a backwards pass\n",
" loss.backward()\n",
"\n",
" # update the parameters\n",
" with torch.no_grad(): # we don't want to track gradients through this part!\n",
" # SGD learning rule: update with negative gradient scaled by lr\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
"\n",
" # ACHTUNG: PyTorch doesn't assume you're done with gradients\n",
" # until you say so -- by explicitly \"deleting\" them,\n",
" # i.e. setting the gradients to 0.\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9J-BfH1e_Jkx"
},
"source": [
"To check whether things are working,\n",
"we confirm that the value of the `loss` has gone down\n",
"and the `accuracy` has gone up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mHgGCLaVJ3yE"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E1ymEPYdcRHO"
},
"source": [
"We can also run the model on a few examples\n",
"to get a sense for how it's doing --\n",
"always good for detecting bugs in our evaluation metrics!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O88PWejlcSTL"
},
"outputs": [],
"source": [
"# re-execute this cell for more samples\n",
"idx = random.randrange(len(x_train))  # upper bound is exclusive, so the slice below is never empty\n",
"example = x_train[idx:idx+1]\n",
"\n",
"out = model(example)\n",
"\n",
"print(out.argmax())\n",
"wandb.Image(example.reshape(28, 28)).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7L1Gq1N_J3yE"
},
"source": [
"# Refactoring with core `torch.nn` components"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EE5nUXMG_Yry"
},
"source": [
"This works!\n",
"But it's rather tedious and manual --\n",
"we have to track what the parameters of our model are,\n",
"apply the parameter updates to each one individually ourselves,\n",
"iterate over the dataset directly, etc.\n",
"\n",
"It's also very literal:\n",
"many assumptions about our problem are hard-coded in the loop.\n",
"If our dataset was, say, stored in CSV files\n",
"and too large to fit in RAM,\n",
"we'd have to rewrite most of our training code.\n",
"\n",
"For the next few sections,\n",
"we'll progressively refactor this code to\n",
"make it shorter, cleaner,\n",
"and more extensible\n",
"using tools from the sublibraries of PyTorch:\n",
"`torch.nn`, `torch.optim`, and `torch.utils.data`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BHEixRsbJ3yF"
},
"source": [
"## Using `torch.nn.functional` for stateless computation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9k94IlN58lWa"
},
"source": [
"First, let's drop that `cross_entropy` and `log_softmax`\n",
"we implemented ourselves --\n",
"whenever you find yourself implementing basic mathematical operations\n",
"in PyTorch code you want to put in production,\n",
"take a second to check whether the code you need is already out\n",
"there in a library somewhere.\n",
"You'll get fewer bugs and faster code for less effort!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sP-giy1a9Ct4"
},
"source": [
"Both of those functions operated on their inputs\n",
"without reference to any global variables,\n",
"so we find their implementation in `torch.nn.functional`,\n",
"where stateless computations live."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vfWyJW1sJ3yF"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"loss_func = F.cross_entropy\n",
"\n",
"def model(xb):\n",
" return xb @ weights + bias"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kqYIkcvpJ3yF"
},
"outputs": [],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vXFyM1tKJ3yF"
},
"source": [
"## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PInL-9sbCKnv"
},
"source": [
"Perhaps the biggest issue with our setup is how we're handling state.\n",
"\n",
"The `model` function refers to two global variables: `weights` and `bias`.\n",
"These variables are critical for it to run,\n",
"but they are defined outside of the function\n",
"and are manipulated willy-nilly by other operations.\n",
"\n",
"This problem arises because of a fundamental tension in\n",
"deep neural networks.\n",
"We want to use them _as functions_ --\n",
"when the time comes to make predictions in production,\n",
"we put inputs in and get outputs out,\n",
"just like any other function.\n",
"But neural networks are fundamentally stateful,\n",
"because they are _parameterized_ functions,\n",
"and fiddling with the values of those parameters\n",
"is the purpose of optimization.\n",
"\n",
"PyTorch's solution to this is the `nn.Module` class:\n",
"a Python class that is callable like a function\n",
"but tracks state like an object.\n",
"\n",
"Whatever `Tensor`s representing state we want PyTorch\n",
"to track for us inside of our model\n",
"get defined as `nn.Parameter`s and attached to the model\n",
"as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A34hxhd0J3yF"
},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"\n",
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__() # the nn.Module.__init__ method does import setup, so this is mandatory\n",
" self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n",
" self.bias = nn.Parameter(torch.zeros(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pFD_sIRaFbbx"
},
"source": [
"We define the computation that uses that state\n",
"in the `.forward` method.\n",
"\n",
"Using some behind-the-scenes magic,\n",
"this method gets called if we treat\n",
"the instantiated `nn.Module` like a function by\n",
"passing it arguments.\n",
"You can give similar special powers to your own classes\n",
"by defining the `__call__` \"magic dunder\" method\n",
"on them.\n",
"\n",
"> We've separated the definition of the `.forward` method\n",
"from the definition of the class above and\n",
"attached the method to the class manually below.\n",
"We only do this to make the construction of the class\n",
"easier to read and understand in the context of this notebook --\n",
"a neat little trick we'll use a lot in these labs.\n",
"Normally, we'd just define the `nn.Module` all at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QAKK3dlFT9w"
},
"outputs": [],
"source": [
"def forward(self, xb: torch.Tensor) -> torch.Tensor:\n",
" return xb @ self.weights + self.bias\n",
"\n",
"MNISTLogistic.forward = forward\n",
"\n",
"model = MNISTLogistic() # instantiated as an object\n",
"print(model(xb)[:4]) # callable like a function\n",
"loss = loss_func(model(xb), yb) # composable like a function\n",
"loss.backward() # we can still take gradients through it\n",
"print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r-Yy2eYTHMVl"
},
"source": [
"But how do we apply our updates?\n",
"Do we need to access `model.weights.grad` and `model.weights`,\n",
"like we did in our first implementation?\n",
"\n",
"Luckily, we don't!\n",
"We can iterate over all of our model's `torch.nn.Parameters`\n",
"via the `.parameters` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vM59vE-5JiXV"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tbFCdWBkNft0"
},
"source": [
"That means we no longer need to assume we know the names\n",
"of the model's parameters when we do our update --\n",
"we can reuse the same loop with different models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hA925fIUK0gg"
},
"source": [
"Let's wrap all of that up into a single function to `fit` our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q9NxJZTOJ3yG"
},
"outputs": [],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" for p in model.parameters(): # finds params automatically\n",
" p -= p.grad * lr\n",
" model.zero_grad()\n",
"\n",
"fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mjmsb94mK8po"
},
"source": [
"and check that we didn't break anything,\n",
"i.e. that our model still gets accuracy much higher than 10%:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vo65cLS5J3yH"
},
"outputs": [],
"source": [
"print(accuracy(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fxYq2sCLJ3yI"
},
"source": [
"# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "95c67wZCMynl"
},
"source": [
"Our model's state is being handled respectably,\n",
"our fitting loop is 2x shorter,\n",
"and we can train different models if we'd like.\n",
"\n",
"But we're not done yet!\n",
"Many steps we're doing manually above\n",
"are already built into `torch`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CE2VFjDZJ3yI"
},
"source": [
"## Using `torch.nn.Linear` for the model definition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvcnrz2uJ3yI"
},
"source": [
"As with our hand-rolled `cross_entropy`\n",
"that could be profitably replaced with\n",
"the industrial grade `nn.functional.cross_entropy`,\n",
"we should replace our bespoke linear layer\n",
"with something made by experts.\n",
"\n",
"Instead of defining `nn.Parameters`,\n",
"effectively raw `Tensor`s, as attributes\n",
"of our `nn.Module`,\n",
"we can define other `nn.Module`s as attributes.\n",
"PyTorch assigns the `nn.Parameters`\n",
"of any child `nn.Module`s to the parent, recursively.\n",
"\n",
"These `nn.Module`s are reusable --\n",
"say, if we want to make a network with multiple layers of the same type --\n",
"and there are lots of them already defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l-EKdhXcPjq2"
},
"outputs": [],
"source": [
"import textwrap\n",
"\n",
"print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KbIIQMaBQC45"
},
"source": [
"We want the humble `nn.Linear`,\n",
"which applies the same\n",
"matrix multiplication and bias operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JHwS-1-rJ3yJ"
},
"outputs": [],
"source": [
"class MNISTLogistic(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n",
"\n",
" def forward(self, xb):\n",
" return self.lin(xb) # call nn.Linear.forward here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mcb0UvcmJ3yJ"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"print(loss_func(model(xb), yb)) # loss is still close to 2.3"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hcjV8A2QjQJ"
},
"source": [
"We can see that the `nn.Linear` module is a \"child\"\n",
"of the `model`,\n",
"and we don't see the matrix of weights and the bias vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yKkU-GIPOQq4"
},
"outputs": [],
"source": [
"print(*list(model.children()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kUdhpItWQui_"
},
"source": [
"but if we ask for the model's `.parameters`,\n",
"we find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G1yGOj2LNDsS"
},
"outputs": [],
"source": [
"print(*list(model.parameters()), sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DFlQyKl6J3yJ"
},
"source": [
"## Applying gradients with `torch.optim.Optimizer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IqImMaenJ3yJ"
},
"source": [
"Applying gradients to optimize parameters\n",
"and resetting those gradients to zero\n",
"are very common operations.\n",
"\n",
"So why are we doing that by hand?\n",
"Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n",
"we don't have to --\n",
"we just need to point a `torch.optim.Optimizer`\n",
"at the parameters of our model.\n",
"\n",
"While we're at it, we can also use a more sophisticated optimizer --\n",
"`Adam` is a common first choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f5AUNLEKJ3yJ"
},
"outputs": [],
"source": [
"from torch import optim\n",
"\n",
"\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jK9dy0sNJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4yk9re3HJ3yK"
},
"source": [
"## Organizing data with `torch.utils.data.Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ap3fcZpTIqJ"
},
"source": [
"We're also manually handling the data.\n",
"First, we're independently and manually aligning\n",
"the inputs, `x_train`, and the outputs, `y_train`.\n",
"\n",
"Aligned data is important in ML.\n",
"We want a way to combine multiple data sources together\n",
"and index into them simultaneously.\n",
"\n",
"That's done with `torch.utils.data.Dataset`.\n",
"Just inherit from it and implement two methods to support indexing:\n",
"`__getitem__` and `__len__`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HPj25nkoVWRi"
},
"source": [
"We'll cheat a bit here and pull in the `BaseDataset`\n",
"class from the `text_recognizer` library,\n",
"so that we can start getting some exposure\n",
"to the codebase for the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NpltQ-4JJ3yK"
},
"outputs": [],
"source": [
"from text_recognizer.data.util import BaseDataset\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV1bc4R5Vz0N"
},
"source": [
"The cell below will pull up the source code and documentation for this class,\n",
"which effectively just indexes into the two `Tensor`s simultaneously.\n",
"\n",
"It can also apply transformations to the inputs and targets.\n",
"We'll see that later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XUWJ8yIWU28G"
},
"outputs": [],
"source": [
"BaseDataset??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMQDHJNzWMtf"
},
"source": [
"This makes our code a tiny bit cleaner:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6iyqG4kEJ3yK"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"opt = configure_optimizer(model)\n",
"\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTtRPp_iJ3yL"
},
"source": [
"## Batching up data with `torch.utils.data.DataLoader`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FPnaMyokWSWv"
},
"source": [
"We're also still manually building our batches.\n",
"\n",
"Making batches out of datasets is a core component of contemporary deep learning training workflows,\n",
"so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n",
"\n",
"We just need to hand our `Dataset` to the `DataLoader`\n",
"and choose a `batch_size`.\n",
"\n",
"We can tune that parameter and other `DataLoader` arguments,\n",
"like `num_workers` and `pin_memory`,\n",
"to improve the performance of our training loop.\n",
"For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n",
"[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aqXX7JGCJ3yL"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iWry2CakJ3yL"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, train_dataloader: DataLoader):\n",
" opt = configure_optimizer(self)\n",
"\n",
" for epoch in range(epochs):\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
"MNISTLogistic.fit = fit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pfdSJBIXT8o"
},
"outputs": [],
"source": [
"model = MNISTLogistic()\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAs8-3IfJ3yL"
},
"source": [
"Compare the ten-line `fit` function with our first training loop (reproduced below) --\n",
"much cleaner _and_ much more powerful!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_a51dZrLJ3yL"
},
"source": [
"```python\n",
"lr = 0.5 # learning rate\n",
"epochs = 2 # how many epochs to train for\n",
"\n",
"for epoch in range(epochs):\n",
" for ii in range((n - 1) // bs + 1):\n",
" start_idx = ii * bs\n",
" end_idx = start_idx + bs\n",
" xb = x_train[start_idx:end_idx]\n",
" yb = y_train[start_idx:end_idx]\n",
" pred = model(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" with torch.no_grad():\n",
" weights -= weights.grad * lr\n",
" bias -= bias.grad * lr\n",
" weights.grad.zero_()\n",
" bias.grad.zero_()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiQe3SEWyZo4"
},
"source": [
"## Swapping in another model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KykHpZEWyZo4"
},
"source": [
"To see that our new `.fit` is more powerful,\n",
"let's use it with a different model.\n",
"\n",
"Specifically, let's draw in the `MLP`,\n",
"or \"multi-layer perceptron\" model\n",
"from the `text_recognizer` library\n",
"in our codebase."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1FtGJg1CyZo4"
},
"outputs": [],
"source": [
"from text_recognizer.models.mlp import MLP\n",
"\n",
"\n",
"MLP.fit = fit # attach our fitting loop"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kJiP3a-8yZo4"
},
"source": [
"If you look in the `.forward` method of the `MLP`,\n",
"you'll see that it uses\n",
"some modules and functions we haven't seen, like\n",
"[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n",
"but otherwise fits the interface of our training loop:\n",
"the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hj-0UdJwyZo4"
},
"outputs": [],
"source": [
"MLP.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FS7dxQ4VyZo4"
},
"source": [
"If we look at the constructor, `__init__`,\n",
"we see that the `nn.Module`s (the fully-connected `fc` layers and `dropout`)\n",
"are initialized and attached as attributes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x0NpkeA8yZo5"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uygy5HsUyZo5"
},
"source": [
"We also see that we are required to provide a `data_config`\n",
"dictionary and can optionally configure the module with `args`.\n",
"\n",
"For now, we'll only do the bare minimum and specify\n",
"the contents of the `data_config`:\n",
"the `input_dims` for `x` and the `mapping`\n",
"from class index in `y` to class label,\n",
"which we can see are used in the `__init__` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y6BEl_I-yZo5"
},
"outputs": [],
"source": [
"digits_to_9 = list(range(10))\n",
"data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n",
"data_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bEuNc38JyZo5"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CWQK2DWWyZo6"
},
"source": [
"The resulting `MLP` is a bit larger than our `MNISTLogistic` model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zs1s6ahUyZo8"
},
"outputs": [],
"source": [
"model.fc1.weight"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JVLkK78FyZo8"
},
"source": [
"But that doesn't matter for our fitting loop,\n",
"which happily optimizes this model on batches from the `train_dataloader`,\n",
"though it takes a bit longer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-DItXLoyZo9"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"print(\"before training:\", loss_func(model(xb), yb))\n",
"\n",
"train_ds = BaseDataset(x_train, y_train)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)\n",
"fit(model, train_dataloader)\n",
"\n",
"print(\"after training:\", loss_func(model(xb), yb))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QgTv2yzJ3yM"
},
"source": [
"# Extra goodies: data organization, validation, and acceleration"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vx-CcCesbmyw"
},
"source": [
"Before we've got a DNN fitting loop that's welcome in polite company,\n",
"we need three more features:\n",
"organized data loading code, validation, and GPU acceleration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LWja5aDJ3yN"
},
"source": [
"## Making the GPU go brrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7juxQ_Kp-Tx0"
},
"source": [
"Everything we've done so far has been on\n",
"the central processing unit of the computer, or CPU.\n",
"When programming in Python,\n",
"it is on the CPU that\n",
"almost all of our code becomes concrete instructions\n",
"that cause a machine to move around electrons."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R25L3z8eAWIO"
},
"source": [
"That's okay for small-to-medium neural networks,\n",
"but computation quickly becomes a bottleneck that makes achieving\n",
"good performance infeasible.\n",
"\n",
"In general, the problem of CPUs,\n",
"which are general purpose computing devices,\n",
"being too slow is solved by using more specialized accelerator chips --\n",
"in the extreme case, application-specific integrated circuits (ASICs)\n",
"that can only perform a single task,\n",
"the hardware equivalents of\n",
"[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n",
"[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n",
"\n",
"Luckily, really excellent chips\n",
"for accelerating deep learning are readily available\n",
"as a consumer product:\n",
"graphics processing units (GPUs),\n",
"which are designed to perform large matrix multiplications in parallel.\n",
"Their name derives from their origins\n",
"applying large matrix multiplications to manipulate shapes and textures\n",
"in graphics engines for video games and CGI.\n",
"\n",
"If your system has a GPU and the right libraries installed\n",
"for `torch` compatibility,\n",
"the cell below will print information about its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xxy-Gt9wJ3yN"
},
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" !nvidia-smi\n",
"else:\n",
" print(\"☹️\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x6qAX1OECiWk"
},
"source": [
"PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n",
"even simultaneously, which can be critical for high performance.\n",
"\n",
"So once we start using acceleration, we need to be more precise about where the\n",
"data inside our `Tensor`s lives --\n",
"on which physical `torch.device` it can be found.\n",
"\n",
"On compatible systems, the cell below will\n",
"move all of the model's parameters `.to` the GPU\n",
"(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n",
"and then move a batch of inputs and targets there as well\n",
"before applying the model and calculating the loss.\n",
"\n",
"To confirm this worked, look for the name of the device in the output of the cell,\n",
"alongside other information about the loss `Tensor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jGkpfEmbJ3yN"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"\n",
"model.to(device)\n",
"\n",
"loss_func(model(xb.to(device)), yb.to(device))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-zdPR06eDjIX"
},
"source": [
"Rather than rewrite our entire `.fit` function,\n",
"we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n",
"\n",
"Specifically,\n",
"we can provide a `transform` that is called on the inputs\n",
"and a `target_transform` that is called on the labels\n",
"before they are returned.\n",
"In the FSDL codebase,\n",
"this feature is used for data preparation, like\n",
"reshaping, resizing,\n",
"and normalization.\n",
"\n",
"We'll use this as an opportunity to put the `Tensor`s on the appropriate device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m8WQS9Zo_Did"
},
"outputs": [],
"source": [
"def push_to_device(tensor):\n",
" return tensor.to(device)\n",
"\n",
"train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
"train_dataloader = DataLoader(train_ds, batch_size=bs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmg9HMSZFmqR"
},
"source": [
"We don't need to change anything about our fitting code to run it on the GPU!\n",
"\n",
"Note: given the small size of this model and the data,\n",
"the speedup here can sometimes be fairly moderate (like 2x).\n",
"For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v1TVc06NkXrU"
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"model.fit(train_dataloader)\n",
"\n",
"print(loss_func(model(push_to_device(xb)), push_to_device(yb)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7thbdjKTjAD"
},
"source": [
"Writing high performance GPU-accelerated neural network code is challenging.\n",
"There are many sharp edges, so the default\n",
"strategy is imitation (basing all work on existing verified quality code)\n",
"and conservatism bordering on paranoia about change.\n",
"For a casual introduction to some of the core principles, see\n",
"[Horace He's blogpost](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LnpbEVE5J3yM"
},
"source": [
"## Adding validation data and organizing data code with a `DataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqYHjiG8b_4J"
},
"source": [
"Just doing well on data you've seen before is not that impressive --\n",
"the network could just memorize the label for each input digit.\n",
"\n",
"We need to check performance on a set of data points that weren't used\n",
"directly to optimize the model,\n",
"commonly called the validation set."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7e6z-Fh8dOnN"
},
"source": [
"We already downloaded one up above,\n",
"but that was all the way at the beginning of the notebook,\n",
"and I've already forgotten about it.\n",
"\n",
"In general, it's easy for data-loading code,\n",
"the redheaded stepchild of the ML codebase,\n",
"to become messy and fall out of sync.\n",
"\n",
"A proper `DataModule` collects up all of the code required\n",
"to prepare data on a machine,\n",
"sets it up as a collection of `Dataset`s,\n",
"and turns those `Dataset`s into `DataLoader`s,\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0WxgRa2GJ3yM"
},
"outputs": [],
"source": [
"class MNISTDataModule:\n",
" url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n",
" filename = \"mnist.pkl.gz\"\n",
" \n",
" def __init__(self, dir, bs=32):\n",
" self.dir = dir\n",
" self.bs = bs\n",
" self.path = self.dir / self.filename\n",
"\n",
" def prepare_data(self):\n",
" if not (self.path).exists():\n",
" content = requests.get(self.url + self.filename).content\n",
" self.path.open(\"wb\").write(content)\n",
"\n",
" def setup(self):\n",
" with gzip.open(self.path, \"rb\") as f:\n",
" ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n",
"\n",
" x_train, y_train, x_valid, y_valid = map(\n",
" torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
" )\n",
" \n",
" self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n",
" self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n",
" \n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x-8T_MlWifMe"
},
"source": [
"We'll cover `DataModule`s in more detail later.\n",
"\n",
"We can now incorporate our `DataModule`\n",
"into the fitting pipeline\n",
"by calling its methods as needed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcFcbRhSJ3yN"
},
"outputs": [],
"source": [
"def fit(self: nn.Module, datamodule):\n",
" datamodule.prepare_data()\n",
" datamodule.setup()\n",
"\n",
" val_dataloader = datamodule.val_dataloader()\n",
" \n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(\"before start of training:\", valid_loss / len(val_dataloader))\n",
"\n",
" opt = configure_optimizer(self)\n",
" train_dataloader = datamodule.train_dataloader()\n",
" for epoch in range(epochs):\n",
" self.train()\n",
" for xb, yb in train_dataloader:\n",
" pred = self(xb)\n",
" loss = loss_func(pred, yb)\n",
"\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
"\n",
" self.eval()\n",
" with torch.no_grad():\n",
" valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n",
"\n",
" print(epoch, valid_loss / len(val_dataloader))\n",
"\n",
"\n",
"MNISTLogistic.fit = fit\n",
"MLP.fit = fit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Uqey9w6jkv9"
},
"source": [
"Now we've substantially cut down on the \"hidden state\" in our fitting code:\n",
"if you've defined the `MLP` and `MNISTDataModule` classes,\n",
"then you can train a network with just the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uxN1yV6DX6Nz"
},
"outputs": [],
"source": [
"model = MLP(data_config)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=32)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2zHA12Iih0ML"
},
"source": [
"You may have noticed a few other changes in the `.fit` method:\n",
"\n",
"- `self.eval` vs `self.train`:\n",
"it's helpful to have features of neural networks that behave differently in `train`ing\n",
"than they do in production or `eval`uation.\n",
"[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n",
"and\n",
"[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n",
"are among the most popular examples.\n",
"We need to take this into account now that we\n",
"have a validation loop.\n",
"- The return of `torch.no_grad`: in our first few implementations,\n",
"we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n",
"Now, we need to use it to avoid tracking gradients during validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaODkqTnJ3yO"
},
"source": [
"This is starting to get a bit hairy again!\n",
"We're back up to about 30 lines of code,\n",
"right where we started\n",
"(but now with way more features!).\n",
"\n",
"Much like `torch.nn` provides useful tools and interfaces for\n",
"defining neural networks,\n",
"iterating over batches,\n",
"and calculating gradients,\n",
"frameworks on top of PyTorch, like\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n",
"provide useful tools and interfaces\n",
"for an even higher level of abstraction over neural network training.\n",
"\n",
"For serious deep learning codebases,\n",
"you'll want to use a framework at that level of abstraction --\n",
"either one of the popular open frameworks or one developed in-house.\n",
"\n",
"For most of these frameworks,\n",
"you'll still need facility with core PyTorch:\n",
"at least for defining models and\n",
"often for defining data pipelines as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-4piIilkyZpD"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E482VfIlyZpD"
},
"source": [
"### 🌟 Try out different hyperparameters for the `MLP` and for training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQ8bkAxNyZpD"
},
"source": [
"The `MLP` class is configured via the `args` argument to its constructor,\n",
"which can set the values of hyperparameters like the width of layers and the degree of dropout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Tl-AvMVyZpD"
},
"outputs": [],
"source": [
"MLP.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HfbQ0KkyZpD"
},
"source": [
"As the type signature indicates, `args` is an `argparse.Namespace`.\n",
"[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n",
"and later on we'll see how to configure models\n",
"and launch training jobs from the command line\n",
"in the FSDL codebase.\n",
"\n",
"For now, we'll do it by hand, by unpacking a dictionary into `Namespace`, as in `Namespace(**some_dict)`.\n",
"\n",
"Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n",
"\n",
"Can you get a final `valid`ation `acc`uracy of 98%?\n",
"Can you get to 95% 2x faster than the baseline `MLP`?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-vVtGJhtyZpD"
},
"outputs": [],
"source": [
"%%time \n",
"from argparse import Namespace # you'll need this\n",
"\n",
"args = None # edit this\n",
"\n",
"epochs = 2 # used in fit\n",
"bs = 32 # used by the DataModule\n",
"\n",
"\n",
"# used in fit, play around with this if you'd like\n",
"def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n",
" return optim.Adam(model.parameters(), lr=3e-4)\n",
"\n",
"\n",
"model = MLP(data_config, args=args)\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7yyxc3uxyZpD"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0ZHygZtgyZpE"
},
"source": [
"### 🌟🌟🌟 Write your own `nn.Module`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r3Iu73j3yZpE"
},
"source": [
"Designing new models is one of the most fun\n",
"aspects of building an ML-powered application.\n",
"\n",
"Can you make an `nn.Module` that looks different from\n",
"the standard `MLP` but still gets 98% validation accuracy or higher?\n",
"You might start from the `MLP` and\n",
"[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n",
"while adding more bells and whistles.\n",
"Take care to keep the shapes of the `Tensor`s aligned as you go.\n",
"\n",
"Here are some tricks you can try that are especially helpful with deeper networks:\n",
"- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n",
"layers, which can improve\n",
"[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n",
"- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n",
"- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n",
"like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n",
"or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n",
"\n",
"If you want to make an `nn.Module` that can have different depths,\n",
"check out the\n",
"[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JsF_RfrDyZpE"
},
"outputs": [],
"source": [
"class YourModel(nn.Module):\n",
" def __init__(self): # add args and kwargs here as you like\n",
" super().__init__()\n",
" # use those args and kwargs to set up the submodules\n",
" self.ps = nn.Parameter(torch.zeros(10))\n",
"\n",
" def forward(self, xb): # overwrite this to use your nn.Modules from above\n",
" xb = torch.stack([self.ps for ii in range(len(xb))])\n",
" return xb\n",
" \n",
" \n",
"YourModel.fit = fit # don't forget this!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t6OQidtGyZpE"
},
"outputs": [],
"source": [
"model = YourModel()\n",
"model.to(device)\n",
"\n",
"datamodule = MNISTDataModule(dir=path, bs=bs)\n",
"\n",
"model.fit(datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CH0U4ODoyZpE"
},
"outputs": [],
"source": [
"val_dataloader = datamodule.val_dataloader()\n",
"valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n",
"valid_acc"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab01_pytorch.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab08/notebooks/lab02a_lightning.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02a: PyTorch Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n",
"- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n",
"- How we use these features in the FSDL codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Lightning?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bP8iJW_bg7IC"
},
"source": [
"PyTorch is a powerful library for executing differentiable\n",
"tensor operations with hardware acceleration\n",
"and it includes many neural network primitives,\n",
"but it has no concept of \"training\".\n",
"At a high level, an `nn.Module` is a stateful function with gradients\n",
"and a `torch.optim.Optimizer` can update that state using gradients,\n",
"but there are no pre-built tools in PyTorch to iteratively generate those gradients from data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a7gIA-Efy91E"
},
"source": [
"So the first thing many folks do in PyTorch is write that code --\n",
"a \"training loop\" to iterate over their `DataLoader`,\n",
"which in pseudocode might look something like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y3ewkWrwzDA8"
},
"source": [
"```python\n",
"for batch in dataloader:\n",
" inputs, targets = batch\n",
"\n",
" outputs = model(inputs)\n",
" loss = some_loss_function(targets, outputs)\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OYUtiJWize82"
},
"source": [
"This is a solid start, but other needs immediately arise.\n",
"You'll want to run your model on validation and test data,\n",
"which need their own `DataLoader`s.\n",
"Once finished, you'll want to save your model --\n",
"and for long-running jobs, you probably want\n",
"to save checkpoints of the training process\n",
"so that it can be resumed in case of a crash.\n",
"For state-of-the-art model performance in many domains,\n",
"you'll want to distribute your training across multiple nodes/machines\n",
"and across multiple GPUs within those nodes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0untumvjy5fm"
},
"source": [
"That's just the tip of the iceberg, and you want\n",
"all those features to work for lots of models and datasets,\n",
"not just the one you're writing now."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNPpi4OZjMbu"
},
"source": [
"You don't want to write all of this yourself.\n",
"\n",
"So unless you are at a large organization that has a dedicated team\n",
"for building that \"framework\" code,\n",
"you'll want to use an existing library."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tnQuyVqUjJy8"
},
"source": [
"PyTorch Lightning is a popular framework on top of PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ecipNFTgZDt"
},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"\n",
"version = pl.__version__\n",
"\n",
"docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n",
"docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bE82xoEikWkh"
},
"source": [
"At its core, PyTorch Lightning provides\n",
"\n",
"1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n",
"2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n",
"\n",
"Both of these are kitted out with all the features\n",
"a cutting-edge deep learning codebase needs:\n",
"- flags for switching device types and distributed computing strategy\n",
"- saving, checkpointing, and resumption\n",
"- calculation and logging of metrics\n",
"\n",
"and much more.\n",
"\n",
"Importantly these features can be easily\n",
"added, removed, extended, or bypassed\n",
"as desired, meaning your code isn't constrained by the framework."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uuJUDmCeT3RK"
},
"source": [
"In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n",
"as shown in the video below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wTt0TBs5TZpm"
},
"outputs": [],
"source": [
"import IPython.display as display\n",
"\n",
"\n",
"display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n",
" width=720, height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGwpDn5GWn_X"
},
"source": [
"That's in contrast to the other common approach to framework design:\n",
"to provide abstractions over the lower-level library\n",
"(here, PyTorch).\n",
"\n",
"Because of this \"organize don't abstract\" style,\n",
"writing PyTorch Lightning code involves\n",
"a lot of overriding of methods --\n",
"you inherit from a class\n",
"and then implement the specific version of a general method\n",
"that you need for your code,\n",
"rather than Lightning providing a bunch of already\n",
"fully-defined classes that you just instantiate,\n",
"using arguments for configuration."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TXiUcQwan39S"
},
"source": [
"# The `pl.LightningModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3FffD5Vn6we"
},
"source": [
"The first of our two core classes,\n",
"the `LightningModule`,\n",
"is like a souped-up `torch.nn.Module` --\n",
"it inherits all of the `Module` features,\n",
"but adds more."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0QWwSStJTP28"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"issubclass(pl.LightningModule, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1wiBVSTuHNT"
},
"source": [
"To demonstrate how this class works,\n",
"we'll build up a `LinearRegression` model dynamically,\n",
"method by method.\n",
"\n",
"For this example we hard code lots of the details,\n",
"but the real benefit comes when the details are configurable.\n",
"\n",
"In order to have a realistic example as well,\n",
"we'll compare to the actual code\n",
"in the `BaseLitModel` we use in the codebase\n",
"as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fPARncfQ3ohz"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import BaseLitModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "myyL0vYU3z0a"
},
"source": [
"A `pl.LightningModule` is a `torch.nn.Module`,\n",
"so the basic definition looks the same:\n",
"we need `__init__` and `forward`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-c0ylFO9rW_t"
},
"outputs": [],
"source": [
"class LinearRegression(pl.LightningModule):\n",
"\n",
" def __init__(self):\n",
" super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n",
"\n",
" # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n",
" self.model = torch.nn.Linear(in_features=1, out_features=1)\n",
" # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n",
"\n",
" # optionally, define a forward method\n",
" def forward(self, xs):\n",
" return self.model(xs) # we like to just call the model's forward method"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZY1yoGTy6CBu"
},
"source": [
"But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n",
"\n",
"If we try to use the class above with the `Trainer`, we get an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tBWh_uHu5rmU"
},
"outputs": [],
"source": [
"import logging  # import some stdlib components to control what's displayed\n",
"import textwrap\n",
"import traceback\n",
"\n",
"\n",
"pl_logger = logging.getLogger(\"pytorch_lightning\")\n",
"previous_level = pl_logger.level  # remember the current level so we can restore it afterwards\n",
"\n",
"try:  # try using the LinearRegression LightningModule defined above\n",
"    pl_logger.setLevel(logging.ERROR)  # hide some info for now\n",
"\n",
"    model = LinearRegression()\n",
"\n",
"    # we'll explain how the Trainer works in a bit\n",
"    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n",
"    trainer.fit(model=model)\n",
"\n",
"except pl.utilities.exceptions.MisconfigurationException as error:\n",
"    print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\")  # show the error without raising it\n",
"\n",
"finally:  # bring back whatever logging level was active before this cell\n",
"    pl_logger.setLevel(previous_level)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s5ni7xe5CgUt"
},
"source": [
"The error message says we need some more methods.\n",
"\n",
"Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "37BXP7nAoBik"
},
"source": [
"#### `.training_step`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ah9MjWz2plFv"
},
"source": [
"The `training_step` method defines,\n",
"naturally enough,\n",
"what to do during a single step of training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "plWEvWG_zRia"
},
"source": [
"Roughly, it gets used like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9RbxZ4idy-C5"
},
"source": [
"```python\n",
"\n",
"# pseudocode modified from the Lightning documentation\n",
"\n",
"# put model in train mode\n",
"model.train()\n",
"\n",
"for batch in train_dataloader:\n",
" # run the train step\n",
" loss = training_step(batch)\n",
"\n",
" # clear gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # backprop\n",
" loss.backward()\n",
"\n",
" # update parameters\n",
" optimizer.step()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cemh_hGJ53nL"
},
"source": [
"Effectively, it maps a batch to a loss value,\n",
"so that PyTorch can backprop through that loss.\n",
"\n",
"The `.training_step` for our `LinearRegression` model is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X8qW2VRRsPI2"
},
"outputs": [],
"source": [
"from typing import Tuple\n",
"\n",
"\n",
"def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" xs, ys = batch # unpack the batch\n",
" outs = self(xs) # apply the model\n",
" loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n",
" return loss\n",
"\n",
"\n",
"LinearRegression.training_step = training_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x2e8m3BRCIx6"
},
"source": [
"If you've written PyTorch code before, you'll notice that we don't mention devices\n",
"or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FkvNpfwqpns5"
},
"source": [
"You can additionally define\n",
"a `validation_step` and a `test_step`\n",
"to define the model's behavior during\n",
"validation and testing loops.\n",
"\n",
"You're invited to define these steps\n",
"in the exercises at the end of the lab.\n",
"\n",
"Inside this step is also where you might calculate other\n",
"values related to inputs, outputs, and loss,\n",
"like non-differentiable metrics (e.g. accuracy, precision, recall).\n",
"\n",
"So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n",
"and the details of the forward pass are deferred to `._run_on_batch` instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xpBkRczao1hr"
},
"outputs": [],
"source": [
"BaseLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "guhoYf_NoEyc"
},
"source": [
"#### `.configure_optimizers`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SCIAWoCEtIU7"
},
"source": [
"Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n",
"\n",
"But we need more than a gradient to do an update.\n",
"\n",
"We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n",
"\n",
"Our second required method, `.configure_optimizers`,\n",
"sets up the `torch.optim.Optimizer`s \n",
"(e.g. setting their hyperparameters\n",
"and pointing them at the `Module`'s parameters)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bMlnRdIPzvDF"
},
"source": [
"In pseudocode (modified from the Lightning documentation), it gets used something like this:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_WBnfJzszi49"
},
"source": [
"```python\n",
"optimizer = model.configure_optimizers()\n",
"\n",
"for batch_idx, batch in enumerate(data):\n",
"\n",
" def closure(): # wrap the loss calculation\n",
" loss = model.training_step(batch, batch_idx, ...)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" return loss\n",
"\n",
" # optimizer can call the loss calculation as many times as it likes\n",
" optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGsP3DBy7YzW"
},
"source": [
"For our `LinearRegression` model,\n",
"we just need to instantiate an optimizer and point it at the parameters of the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZWrWGgdVt21h"
},
"outputs": [],
"source": [
"def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n",
" return optimizer\n",
"\n",
"\n",
"LinearRegression.configure_optimizers = configure_optimizers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ta2hs0OLwbtF"
},
"source": [
"You can read more about optimization in Lightning,\n",
"including how to manually control optimization\n",
"instead of relying on default behavior,\n",
"in the docs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KXINqlAgwfKy"
},
"outputs": [],
"source": [
"optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n",
"optimization_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zWdKdZDfxmb2"
},
"source": [
"The `configure_optimizers` method for the `BaseLitModel`\n",
"isn't that much more complex.\n",
"\n",
"We just add support for learning rate schedulers:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kyRbz0bEpWwd"
},
"outputs": [],
"source": [
"BaseLitModel.configure_optimizers??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ilQCfn7Nm_QP"
},
"source": [
"# The `pl.Trainer`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RScc0ef97qlc"
},
"source": [
"The `LightningModule` has already helped us organize our code,\n",
"but it's not really useful until we combine it with the `Trainer`,\n",
"which relies on the `LightningModule` interface to execute training, validation, and testing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bBdikPBF86Qp"
},
"source": [
"The `Trainer` is where we make choices like how long to train\n",
"(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n",
"what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n",
"and other settings that might differ across training runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YQ4KSdFP3E4Q"
},
"outputs": [],
"source": [
"trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S2l3rGZK7-PL"
},
"source": [
"Before we can actually use the `Trainer`, though,\n",
"we also need a `torch.utils.data.DataLoader` --\n",
"nothing new from PyTorch Lightning here,\n",
"just vanilla PyTorch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OcUSD2jP4Ffo"
},
"outputs": [],
"source": [
"class CorrelatedDataset(torch.utils.data.Dataset):\n",
"\n",
" def __init__(self, N=10_000):\n",
" self.N = N\n",
" self.xs = torch.randn(size=(N, 1))\n",
" self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n",
"\n",
" def __getitem__(self, idx):\n",
" return (self.xs[idx], self.ys[idx])\n",
"\n",
" def __len__(self):\n",
" return self.N\n",
"\n",
"\n",
"dataset = CorrelatedDataset()\n",
"tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o0u41JtA8qGo"
},
"source": [
"We can fetch some sample data from the `DataLoader`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z1j6Gj9Ka0dJ"
},
"outputs": [],
"source": [
"example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n",
"\n",
"print(\"xs:\", example_xs[:10], sep=\"\\n\")\n",
"print(\"ys:\", example_ys[:10], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Nnqk3mRv8dbW"
},
"source": [
"and, since it's low-dimensional, visualize it\n",
"and see what we're asking the model to learn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "33jcHbErbl6Q"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", kind=\"scatter\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pA7-4tJJ9fde"
},
"source": [
"Now we're ready to run training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IY910O803oPU"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer.fit(model=model, train_dataloaders=tdl)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQBXYmLF_GoI"
},
"source": [
"The loss after training should be less than the loss before training,\n",
"and we can see that our model's predictions line up with the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqcbA91x96-s"
},
"outputs": [],
"source": [
"ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n",
" .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n",
"\n",
"inps = torch.arange(-2, 2, 0.5)[:, None]\n",
"ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZkpsNfl3P8R"
},
"source": [
"The `Trainer` promises to \"customize every aspect of training via flags\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Q-c9b62_XFj"
},
"outputs": [],
"source": [
"pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "He-zEwMB_oKH"
},
"source": [
"and they mean _every_ aspect.\n",
"\n",
"The cell below prints all of the arguments for the `pl.Trainer` class --\n",
"no need to memorize or even understand them all now,\n",
"just skim it to see how many customization options there are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8F_rRPL3lfPE"
},
"outputs": [],
"source": [
"print(pl.Trainer.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4X8dGmR53kYU"
},
"source": [
"It's probably easier to read them on the documentation website:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cqUj6MxRkppr"
},
"outputs": [],
"source": [
"trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n",
"trainer_docs_link"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3T8XMYvr__Y5"
},
"source": [
"# Training with PyTorch Lightning in the FSDL Codebase"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CtaPliTAxy3"
},
"source": [
"The `LightningModule`s in the FSDL codebase\n",
"are stored in the `lit_models` submodule of the `text_recognizer` module.\n",
"\n",
"For now, we've just got some basic models.\n",
"We'll add more as we go."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NMe5z1RSAyo_"
},
"outputs": [],
"source": [
"!ls text_recognizer/lit_models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZTYmIHbBu7g"
},
"source": [
"We also have a folder called `training` now.\n",
"\n",
"This contains a script, `run_experiment.py`,\n",
"that is used for running training jobs.\n",
"\n",
"In case you want to play around with the training code\n",
"in a notebook, you can also load it as a module:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DRz9GbXzNJLM"
},
"outputs": [],
"source": [
"!ls training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Im9vLeyqBv_h"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u2hcAXqHAV0v"
},
"source": [
"We build the `Trainer` from command line arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yi50CDZul7Mm"
},
"outputs": [],
"source": [
"# how the trainer is initialized in the training script\n",
"!grep \"pl.Trainer.from\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bZQheYJyAxlh"
},
"source": [
"so all the configuration flexibility and complexity of the `Trainer`\n",
"is available via the command line.\n",
"\n",
"Docs for the command line arguments for the trainer are accessible with `--help`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XlSmSyCMAw7Z"
},
"outputs": [],
"source": [
"# displays the first few flags for controlling the Trainer from the command line\n",
"!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mIZ_VRPcNMsM"
},
"source": [
"We'll use `run_experiment` in\n",
"[Lab 02b](http://fsdl.me/lab02b-colab)\n",
"to train convolutional neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z0siaL4Qumc_"
},
"source": [
"# Extra Goodies"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PkQSPnxQDBF6"
},
"source": [
"The `LightningModule` and the `Trainer` are the minimum amount you need\n",
"to get started with PyTorch Lightning.\n",
"\n",
"But they aren't all you need.\n",
"\n",
"There are many more features built into Lightning and its ecosystem.\n",
"\n",
"We'll cover three more here:\n",
"- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n",
"- `pl.Callback`s, for adding \"optional\" extra features to model training\n",
"- `torchmetrics`, for efficiently computing and logging metrics"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GOYHSLw_D8Zy"
},
"source": [
"## `pl.LightningDataModule`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rpjTNGzREIpl"
},
"source": [
"Where the `LightningModule` organizes our model and its optimizers,\n",
"the `LightningDataModule` organizes our dataloading code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_KkQ0iOWKD7"
},
"source": [
"The class-level docstring explains the concept\n",
"behind the class well\n",
"and lists the main methods to be overridden:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFTWHdsFV5WG"
},
"outputs": [],
"source": [
"print(pl.LightningDataModule.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rLiacppGB9BB"
},
"source": [
"Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "m1d62iC6Xv1i"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"class CorrelatedDataModule(pl.LightningDataModule):\n",
"\n",
"    def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n",
"        super().__init__()  # again, mandatory superclass init, as with torch.nn.Modules\n",
"\n",
"        # set some constants, like the train/val split and the batch size\n",
"        self.size = size\n",
"        self.batch_size = batch_size  # store it so dataloader methods can use it instead of a hard-coded value\n",
"        self.train_frac, self.val_frac = train_frac, 1 - train_frac\n",
"\n",
"        # the first `num_train` indices go to training and the rest to validation,\n",
"        # so the two index lists are disjoint and no example is used in both splits\n",
"        num_train = math.floor(self.size * train_frac)\n",
"        self.train_indices = list(range(num_train))\n",
"        self.val_indices = list(range(num_train, self.size))\n",
"\n",
"        # under the hood, we've still got a torch Dataset\n",
"        self.dataset = CorrelatedDataset(N=size)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qQf-jUYRCi3m"
},
"source": [
"`LightningDataModule`s are designed to work in distributed settings,\n",
"where operations that set state\n",
"(e.g. writing to disk or attaching something to `self` that you want to access later)\n",
"need to be handled with care.\n",
"\n",
"Getting data ready for training is often a very stateful operation,\n",
"so the `LightningDataModule` provides two separate methods for it:\n",
"one called `setup` that handles any state that needs to be set up in each copy of the module\n",
"(here, splitting the data and adding it to `self`)\n",
"and one called `prepare_data` that handles any state that only needs to be set up once on each machine\n",
"(for example, downloading data from storage and writing it to the local disk)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mttu--rHX70r"
},
"outputs": [],
"source": [
"def setup(self, stage=None):  # prepares state that needs to be set for each GPU on each node\n",
"    # \"fit\" needs both splits and \"validate\" (trainer.validate) needs the val split;\n",
"    # the remaining stages, \"test\" and \"predict\", don't use either split\n",
"    if stage in (\"fit\", \"validate\", None):\n",
"        self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n",
"        self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n",
"\n",
"def prepare_data(self):  # prepares state that needs to be set once per node\n",
"    pass  # but we don't have any \"node-level\" computations\n",
"\n",
"\n",
"CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rh3mZrjwD83Y"
},
"source": [
"We then define methods to return `DataLoader`s when requested by the `Trainer`.\n",
"\n",
"To run a testing loop that uses a `LightningDataModule`,\n",
"you'll also need to define a `test_dataloader`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xu9Ma3iKYPBd"
},
"outputs": [],
"source": [
"def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n",
"\n",
"def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n",
"\n",
"CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNodiN6oawX5"
},
"source": [
"Now we're ready to run training using a datamodule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JKBwoE-Rajqw"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"trainer.fit(model=model, datamodule=datamodule)\n",
"\n",
"print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Bw6flh5Jf2ZP"
},
"source": [
"Notice the warning: \"`Skipping val loop.`\"\n",
"\n",
"It's being raised because our minimal `LinearRegression` model\n",
"doesn't have a `.validation_step` method.\n",
"\n",
"In the exercises, you're invited to add a validation step and resolve this warning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rJnoFx47ZjBw"
},
"source": [
"In the FSDL codebase,\n",
"we define the basic functions of a `LightningDataModule`\n",
"in the `BaseDataModule` and defer details to subclasses:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PTPKvDDGXmOr"
},
"outputs": [],
"source": [
"from text_recognizer.data import BaseDataModule\n",
"\n",
"\n",
"BaseDataModule??"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3mRlZecwaKB4"
},
"outputs": [],
"source": [
"from text_recognizer.data.mnist import MNIST\n",
"\n",
"\n",
"MNIST??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uQbMY08qD-hm"
},
"source": [
"## `pl.Callback`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVe7TSNvHK4K"
},
"source": [
"Lightning's `Callback` class is used to add \"nice-to-have\" features\n",
"to training, validation, and testing\n",
"that aren't strictly necessary for any model to run\n",
"but are useful for many models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RzU76wgFGw9N"
},
"source": [
"A \"callback\" is a unit of code that's meant to be called later,\n",
"based on some trigger.\n",
"\n",
"It's a very flexible system, which is why\n",
"`Callback`s are used internally to implement lots of important Lightning features,\n",
"including some we've already discussed, like `ModelCheckpoint` for saving during training:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-msDjbKdHTxU"
},
"outputs": [],
"source": [
"pl.callbacks.__all__ # builtin Callbacks from Lightning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6WRNXtHHkbM"
},
"source": [
"Here, the triggers, or \"hooks\", are specific points in the training, validation, and testing loops.\n",
"\n",
"The names of the hooks generally explain when the hook will be called,\n",
"but you can always check the documentation for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3iHjjnU8Hvgg"
},
"outputs": [],
"source": [
"hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n",
"print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2E2M7O2cGdj7"
},
"source": [
"You can define your own `Callback` by inheriting from `pl.Callback`\n",
"and overriding one of the \"hook\" methods --\n",
"much the same way that you define your own `LightningModule`\n",
"by writing your own `.training_step` and `.configure_optimizers`.\n",
"\n",
"Let's define a silly `Callback` just to demonstrate the idea:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UodFQKAGEJlk"
},
"outputs": [],
"source": [
"class HelloWorldCallback(pl.Callback):\n",
"\n",
" def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the start of the training epoch!\")\n",
"\n",
" def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n",
" print(\"👋 hello from the end of the validation epoch!\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MU7oIpyEGoaP"
},
"source": [
"This callback will print a message whenever the training epoch starts\n",
"and whenever the validation epoch ends.\n",
"\n",
"Different \"hooks\" have different information directly available.\n",
"\n",
"For example, you can directly access the batch information\n",
"inside the `on_train_batch_start` and `on_train_batch_end` hooks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U17Qo_i_GCya"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"\n",
"def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n",
" if random.random() > 0.995:\n",
" print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n",
"\n",
"\n",
"HelloWorldCallback.on_train_batch_start = on_train_batch_start"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LVKQXZOwQNGJ"
},
"source": [
"We provide the callbacks when initializing the `Trainer`,\n",
"then they are invoked during model fitting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-XHXZ64-ETCz"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n",
" max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UEHUUhVOQv6K"
},
"outputs": [],
"source": [
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pP2Xj1woFGwG"
},
"source": [
"You can read more about callbacks in the documentation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "COHk5BZvFJN_"
},
"outputs": [],
"source": [
"callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n",
"callback_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y2K9e44iEGCR"
},
"source": [
"## `torchmetrics`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dO-UIFKyJCqJ"
},
"source": [
"DNNs are also finicky and break silently:\n",
"rather than crashing, they just start doing the wrong thing.\n",
"Without careful monitoring, that wrong thing can be invisible\n",
"until long after it has done a lot of damage to you, your team, or your users.\n",
"\n",
"We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n",
"or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n",
"meaning we can also determine\n",
"how to fix bugs in training just by viewing logs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z4YMyUI0Jr2f"
},
"source": [
"But DNN training is also performance sensitive.\n",
"Training runs for large language models have budgets that are\n",
"more comparable to building an apartment complex\n",
"than they are to the build jobs of traditional software pipelines.\n",
"\n",
"Slowing down training even a small amount can add a substantial dollar cost,\n",
"offsetting the benefits of catching and fixing bugs more quickly.\n",
"\n",
"Also, implementing metric calculation during training adds extra work,\n",
"much like the other software engineering best practices which it closely resembles,\n",
"namely test-writing and monitoring.\n",
"This distracts and detracts from higher-leverage research work."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sbvWjiHSIxzM"
},
"source": [
"\n",
"The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n",
"resolves these issues by providing a `Metric` class that\n",
"incorporates best performance practices,\n",
"like smart accumulation across batches and over devices,\n",
"defines a unified interface,\n",
"and integrates with Lightning's built-in logging."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21y3lgvwEKPC"
},
"outputs": [],
"source": [
"import torchmetrics\n",
"\n",
"\n",
"tm_version = torchmetrics.__version__\n",
"print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9TuPZkV1gfFE"
},
"source": [
"Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n",
"\n",
"That's because metric calculation, like module application, is typically\n",
"1) an array-heavy computation that\n",
"2) relies on persistent state\n",
"(parameters for `Module`s, running values for `Metric`s) and\n",
"3) benefits from acceleration and\n",
"4) can be distributed over devices and nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "leiiI_QDS2_V"
},
"outputs": [],
"source": [
"issubclass(torchmetrics.Metric, torch.nn.Module)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wy8MF2taP8MV"
},
"source": [
"Documentation for the version of `torchmetrics` we're using can be found here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LN4ashooP_tM"
},
"outputs": [],
"source": [
"torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n",
"torchmetrics_docs_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5aycHhZNXwjr"
},
"source": [
"In the `BaseLitModel`,\n",
"we use the `torchmetrics.Accuracy` metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Vyq4IjmBXzTv"
},
"outputs": [],
"source": [
"BaseLitModel.__init__??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KPoTH50YfkMF"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hD_6PVAeflWw"
},
"source": [
"### 🌟 Add a `validation_step` to the `LinearRegression` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5KKbAN9eK281"
},
"outputs": [],
"source": [
"def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"\n",
"LinearRegression.validation_step = validation_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AnPPHAPxFCEv"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"# if your code is working, you should see results for the validation loss in the output\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u42zXktOFDhZ"
},
"source": [
"### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cbWfqvumFESV"
},
"outputs": [],
"source": [
"def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n",
" pass # your code here\n",
"\n",
"LinearRegression.test_step = test_step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pB96MpibLeJi"
},
"outputs": [],
"source": [
"class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n",
"\n",
" def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n",
" super().__init__() # don't forget this!\n",
" self.dataset = None\n",
" self.test_dataset = None # define a test set -- another sample from the same distribution\n",
"\n",
" def setup(self, stage=None):\n",
" pass\n",
"\n",
" def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n",
" pass # create a dataloader for the test set here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1jq3dcugMMOu"
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModuleWithTest()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# we run testing without fitting here\n",
"trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JHg4MKmJPla6"
},
"source": [
"### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M_1AKGWRR2ai"
},
"source": [
"The \"variance explained\" is a useful metric for comparing regression models --\n",
"its values are interpretable and comparable across datasets, unlike raw loss values.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vLecK4CsQWKk"
},
"source": [
"Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n",
"add metrics and metric logging\n",
"to a `LightningModule`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cWy0HyG4RYnX"
},
"outputs": [],
"source": [
"torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n",
"torchmetrics_guide_url"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UoSQ3y6sSTvP"
},
"source": [
"And check out the docs for `ExplainedVariance` to see how it's calculated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GpGuRK2FRHh1"
},
"outputs": [],
"source": [
"print(torchmetrics.ExplainedVariance.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_EAtpWXrSVR1"
},
"source": [
"You'll want to start the `LinearRegression` class over from scratch,\n",
"since the `__init__` and `{training, validation, test}_step` methods need to be rewritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rGtWt3_5SYTn"
},
"outputs": [],
"source": [
"# your code here"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFWNr1SfS5-r"
},
"source": [
"You can test your code by running fitting and testing.\n",
"\n",
"To see whether it's working,\n",
"[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n",
"with the\n",
"[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n",
"You should see the explained variance show up in the output alongside the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jse95DGCS6gR",
"scrolled": false
},
"outputs": [],
"source": [
"model = LinearRegression()\n",
"datamodule = CorrelatedDataModule()\n",
"\n",
"dataset = datamodule.dataset\n",
"\n",
"trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n",
"\n",
"# if your code is working, you should see explained variance in the progress bar/logs\n",
"trainer.fit(model=model, datamodule=datamodule)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab02a_lightning.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab08/notebooks/lab02b_cnn.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 02b: Training a CNN on Synthetic Handwriting Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Fundamental principles for building neural networks with convolutional components\n",
"- How to use Lightning's training framework via a CLI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 2\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True to re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9HoYWZKtTE_"
},
"source": [
"The most basic neural networks,\n",
"multi-layer perceptrons,\n",
"are built by alternating\n",
"parameterized linear transformations\n",
"with non-linear transformations.\n",
"\n",
"This combination is capable of expressing\n",
"[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n",
"so long as those functions\n",
"take in fixed-size arrays and return fixed-size arrays.\n",
"\n",
"```python\n",
"def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n",
" return some_mlp_that_might_be_impractically_huge(x)\n",
"```\n",
"\n",
"But not all functions have that type signature.\n",
"\n",
"For example, we might want to identify the content of images\n",
"that have different sizes.\n",
"Without gross hacks,\n",
"an MLP won't be able to solve this problem,\n",
"even though it seems simple enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6LjfV3o6tTFA"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import IPython.display as display\n",
"\n",
"randsize = 10 ** (random.random() * 2 + 1)\n",
"\n",
"Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n",
"\n",
"# run multiple times to display the same image at different sizes\n",
"# the content of the image remains unambiguous\n",
"display.Image(url=Url, width=randsize, height=randsize)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c9j6YQRftTFB"
},
"source": [
"Even worse, MLPs are too general to be efficient.\n",
"\n",
"Each layer applies an unstructured matrix to its inputs.\n",
"But most of the data we might want to apply them to is highly structured,\n",
"and taking advantage of that structure can make our models more efficient.\n",
"\n",
"It may seem appealing to use an unstructured model:\n",
"it can in principle learn any function.\n",
"But\n",
"[most functions are monstrous outrages against common sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n",
"It is useful to encode some of our assumptions\n",
"about the kinds of functions we might want to learn\n",
"from our data into our model's architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jvC_yZvmuwgJ"
},
"source": [
"## Convolutions are the local, translation-equivariant linear transforms."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PhnRx_BZtTFC"
},
"source": [
"One of the most common types of structure in data is \"locality\" --\n",
"the most relevant information for understanding or predicting a pixel\n",
"is a small number of pixels around it.\n",
"\n",
"Locality is a fundamental feature of the physical world,\n",
"so it shows up in data drawn from physical observations,\n",
"like photographs and audio recordings.\n",
"\n",
"Locality means most meaningful linear transformations of our input\n",
"only have large weights in a small number of entries that are close to one another,\n",
"rather than having equally large weights in all entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SSnkzV2_tTFC"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"generic_linear_transform = torch.randn(8, 1)\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"local_linear_transform = torch.tensor([\n",
" [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n",
"print(\"local:\", local_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0nCD75NwtTFD"
},
"source": [
"Another type of structure commonly observed is \"translation equivariance\" --\n",
"the top-left pixel position is not, in itself, meaningfully different\n",
"from the bottom-right position\n",
"or a position in the middle of the image.\n",
"Relative relationships matter more than absolute relationships.\n",
"\n",
"Translation equivariance arises in images because there is generally no privileged\n",
"vantage point for taking the image.\n",
"We could just as easily have taken the image while standing a few feet to the left or right,\n",
"and all of its contents would shift along with our change in perspective.\n",
"\n",
"Translation equivariance means that a linear transformation that is meaningful at one position\n",
"in our input is likely to be meaningful at all other points.\n",
"We can learn something about a linear transformation from a datapoint where it is useful\n",
"in the bottom-left and then apply it to another datapoint where it's useful in the top-right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "srvI7JFAtTFE"
},
"outputs": [],
"source": [
"generic_linear_transform = torch.arange(8)[:, None]\n",
"print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n",
"\n",
"equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n",
"print(\"translation invariant:\", equivariant_linear_transform, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qF576NCvtTFE"
},
"source": [
"A linear transformation that is translation equivariant\n",
"[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n",
"\n",
"If the weights of that linear transformation are mostly zero\n",
"except for a few that are close to one another,\n",
"that convolution is said to have a _kernel_."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9tp4tBgWtTFF"
},
"outputs": [],
"source": [
"# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n",
"conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n",
"\n",
"conv_layer.weight # aka kernel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deXA_xS6tTFF"
},
"source": [
"Instead of using normal matrix multiplication to apply the kernel to the input,\n",
"we repeatedly apply that kernel over and over again,\n",
"\"sliding\" it over the input to produce an output.\n",
"\n",
"Every convolution kernel has an equivalent matrix form,\n",
"which can be matrix multiplied with the input to create the output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mFoSsa5DtTFF"
},
"outputs": [],
"source": [
"conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n",
"conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n",
"print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VJyRtf9NtTFG"
},
"source": [
"> Under the hood, the actual operation that implements the application of a convolutional kernel\n",
"need not look like either of these\n",
"(common approaches include\n",
"[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n",
"and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xytivdcItTFG"
},
"source": [
"Though they may seem somewhat arbitrary and technical,\n",
"convolutions are actually a deep and fundamental piece of mathematics and computer science.\n",
"Fundamental as in\n",
"[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n",
"and deep as in\n",
"[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n",
"Generalized convolutions can show up\n",
"wherever there is some kind of \"sum\" over some kind of \"paths\",\n",
"as is common in dynamic programming.\n",
"\n",
"In the context of this course,\n",
"we don't have time to dive much deeper on convolutions or convolutional neural networks.\n",
"\n",
"See Chris Olah's blog series\n",
"([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n",
"[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n",
"[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n",
"for a friendly introduction to the mathematical view of convolution.\n",
"\n",
"For more on convolutional neural network architectures, see\n",
"[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCJTwCWYzRee"
},
"source": [
"## We apply two-dimensional convolutions to images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8RKOPAIx0O2"
},
"source": [
"In building our text recognizer,\n",
"we're working with images.\n",
"Images have two dimensions of translation equivariance:\n",
"left/right and up/down.\n",
"So we use two-dimensional convolutions,\n",
"instantiated in `torch.nn` as `nn.Conv2d` layers.\n",
"Note that convolutional neural networks for images\n",
"are so popular that when the term \"convolution\"\n",
"is used without qualifier in a neural network context,\n",
"it can be taken to mean two-dimensional convolutions.\n",
"\n",
"Where `Linear` layers took in batches of vectors of a fixed size\n",
"and returned batches of vectors of a fixed size,\n",
"`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n",
"and return batches of two-dimensional stacked feature maps.\n",
"\n",
"A pseudocode type signature based on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n",
"might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sJvMdHL7w_lu"
},
"source": [
"```python\n",
"StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n",
"StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n",
"def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSMC8Fw3zPSz"
},
"source": [
"Here, \"map\" is meant to evoke space:\n",
"our feature maps tell us where\n",
"features are spatially located.\n",
"\n",
"An RGB image is a stacked feature map.\n",
"It is composed of three feature maps.\n",
"The first tells us where the \"red\" feature is present,\n",
"the second \"green\", the third \"blue\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jIXT-mym3ljt"
},
"outputs": [],
"source": [
"display.Image(\n",
" url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8WfCcO5xJ-hG"
},
"source": [
"When we apply a convolutional layer to a stacked feature map with some number of channels,\n",
"we get back a stacked feature map with some number of channels.\n",
"\n",
"This output is also a stack of feature maps,\n",
"and so it is a perfectly acceptable\n",
"input to another convolutional layer.\n",
"That means we can compose convolutional layers together,\n",
"just as we composed generic linear layers together.\n",
"We again weave non-linear functions in between our linear convolutions,\n",
"creating a _convolutional neural network_, or CNN."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R18TsGubJ_my"
},
"source": [
"## Convolutional neural networks build up visual understanding layer by layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eV03KmYBz2QM"
},
"source": [
"What is the equivalent of the labels, red/green/blue,\n",
"for the channels in these feature maps?\n",
"What does a high activation in some position in channel 32\n",
"of the fifteenth layer of my network tell me?\n",
"\n",
"There is no guaranteed way to automatically determine the answer,\n",
"nor is there a guarantee that the result is human-interpretable.\n",
"OpenAI's Clarity team spent several years \"reverse engineering\"\n",
"state-of-the-art convolutional neural networks trained on photographs\n",
"and found that many of these channels are\n",
"[directly interpretable](https://distill.pub/2018/building-blocks/).\n",
"\n",
"For example, they found that if they pass an image through\n",
"[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n",
"aka InceptionV1,\n",
"the winner of the\n",
"[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/),"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "64KJR70q6dCh"
},
"outputs": [],
"source": [
"# a sample image\n",
"display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJ7CvvG78CZ5"
},
"source": [
"the features become increasingly complex,\n",
"with channels in early layers (left)\n",
"acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n",
"and channels in later layers (right)\n",
"acting as feature maps for increasingly abstract concepts,\n",
"like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6w5_RR8d9jEY"
},
"outputs": [],
"source": [
"# from https://distill.pub/2018/building-blocks/\n",
"display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HLiqEwMY_Co0"
},
"source": [
"> The small square images depict a heuristic estimate\n",
"of what the entire collection of feature maps\n",
"at a given layer represents (layer IDs at bottom).\n",
"They are arranged in a spatial grid and their sizes represent\n",
"the total magnitude of the layer's activations at that position.\n",
"For details and interactivity, see\n",
"[the original Distill article](https://distill.pub/2018/building-blocks/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vl8XlEsaA54W"
},
"source": [
"In the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"blogpost series,\n",
"the OpenAI Clarity team\n",
"combines careful examination of weights\n",
"with direct experimentation\n",
"to build an understanding of how these higher-level features\n",
"are constructed in GoogLeNet.\n",
"\n",
"For example,\n",
"they are able to provide reasonable interpretations for\n",
"[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"The cell below will pull down their \"weight explorer\"\n",
"and embed it in this notebook.\n",
"By default, it starts on\n",
"[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n",
"which constructs a large, phase-invariant\n",
"[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n",
"from smaller, phase-sensitive filters.\n",
"It is in turn used to construct\n",
"[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n",
"and\n",
"[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n",
"detectors --\n",
"click on any image to navigate to the weight explorer page\n",
"for that channel\n",
"or change the `layer` and `idx`\n",
"arguments.\n",
"For additional context,\n",
"check out the\n",
"[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n",
"\n",
"Click the \"View this neuron in the OpenAI Microscope\" link\n",
"for an even richer interactive view,\n",
"including activations on sample images\n",
"([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n",
"\n",
"The\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"which this explorer accompanies\n",
"is chock-full of empirical observations, theoretical speculation, and nuggets of wisdom\n",
"that are invaluable for developing intuition about both\n",
"convolutional networks in particular and visual perception in general."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I4-hkYjdB-qQ"
},
"outputs": [],
"source": [
"layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n",
"layer = layers[1]\n",
"idx = 52\n",
"\n",
"weight_explorer = display.IFrame(\n",
" src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n",
"weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n",
"weight_explorer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJ6_PCmVtTFH"
},
"source": [
"# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N--VkRtR5Yr-"
},
"source": [
"If we load up the `CNN` class from `text_recognizer.models`,\n",
"we'll see that a `data_config` is required to instantiate the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N3MA--zytTFH"
},
"outputs": [],
"source": [
"import text_recognizer.models\n",
"\n",
"\n",
"text_recognizer.models.CNN??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7yCP46PO6XDg"
},
"source": [
"So before we can make our convolutional network and train it,\n",
"we'll need to get a hold of some data.\n",
"This isn't a general constraint by the way --\n",
"it's an implementation detail of the `text_recognizer` library.\n",
"But datasets and models are generally coupled,\n",
"so it's common for them to share configuration information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Z42K-jjtTFH"
},
"source": [
"## The `EMNIST` Handwritten Character Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oiifKuu4tTFH"
},
"source": [
"We could just use `MNIST` here,\n",
"as we did in\n",
"[the first lab](https://fsdl.me/lab01-colab).\n",
"\n",
"But we're aiming to eventually build a handwritten text recognition system,\n",
"which means we need to handle letters and punctuation,\n",
"not just numbers.\n",
"\n",
"So we instead use _EMNIST_,\n",
"or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n",
"which includes letters and punctuation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3ePZW1Tfa00K"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist = text_recognizer.data.EMNIST() # configure\n",
"print(emnist.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D_yjBYhla6qp"
},
"source": [
"We've built a PyTorch Lightning `DataModule`\n",
"to encapsulate all the code needed to get this dataset ready to go:\n",
"downloading to disk,\n",
"[reformatting to make loading faster](https://www.h5py.org/),\n",
"and splitting into training, validation, and test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty2vakBBtTFI"
},
"outputs": [],
"source": [
"emnist.prepare_data() # download, save to disk\n",
"emnist.setup() # create torch.utils.data.Datasets, do train/val split"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5h9bAXcu8l5J"
},
"source": [
"A brief aside: you might be wondering where this data goes.\n",
"Datasets are saved to disk inside the repo folder,\n",
"but not tracked in version control.\n",
"`git` works well for versioning source code\n",
"and other text files, but it's a poor fit for large binary data.\n",
"We only track and version metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E5cwDCM88SnU"
},
"outputs": [],
"source": [
"!echo {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname()}\n",
"!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IdsIBL9MtTFI"
},
"source": [
"This class comes with a pretty printing method\n",
"for quick examination of some of that metadata and basic descriptive statistics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Cyw66d6GtTFI"
},
"outputs": [],
"source": [
"emnist"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QT0burlOLgoH"
},
"source": [
"\n",
"> You can add pretty printing to your own Python classes by writing\n",
"`__str__` or `__repr__` methods for them.\n",
"The former is generally expected to be human-readable,\n",
"while the latter is generally expected to be machine-readable;\n",
"we've broken with that custom here and used `__repr__`. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XJF3G5idtTFI"
},
"source": [
"Because we've run `.prepare_data` and `.setup`,\n",
"we can expect that this `DataModule` is ready to provide a `DataLoader`\n",
"if we invoke the right method --\n",
"sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n",
"even when we're not using the `Trainer` class itself,\n",
"[as described in Lab 2a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XJghcZkWtTFI"
},
"outputs": [],
"source": [
"xs, ys = next(iter(emnist.train_dataloader()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "40FWjMT-tTFJ"
},
"source": [
"Run the cell below to inspect random elements of this batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hywyEI_tTFJ"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"idx = random.randint(0, len(xs) - 1)\n",
"\n",
"print(emnist.mapping[ys[idx]])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdg_wYWntTFJ"
},
"source": [
"## Putting convolutions in a `torch.nn.Module`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGuSx_zvtTFJ"
},
"source": [
"Because we have the data,\n",
"we now have a `data_config`\n",
"and can instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rxLf7-5jtTFJ"
},
"outputs": [],
"source": [
"data_config = emnist.config()\n",
"\n",
"cnn = text_recognizer.models.CNN(data_config)\n",
"cnn # reveals the nn.Modules attached to our nn.Module"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jkeJNVnIMVzJ"
},
"source": [
"We can run this network on our inputs,\n",
"but we don't expect it to produce correct outputs without training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4EwujOGqMAZY"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = cnn(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3L8u0estTFJ"
},
"source": [
"We can inspect the `.forward` method to see how these `nn.Module`s are used.\n",
"\n",
"> Note: we encourage you to read through the code --\n",
"either inside the notebooks, as below,\n",
"in your favorite text editor locally, or\n",
"[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n",
"There's lots of useful bits of Python that we don't have time to cover explicitly in the labs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RtA0W8jvtTFJ"
},
"outputs": [],
"source": [
"cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VCycQ88gtTFK"
},
"source": [
"We apply convolutions followed by non-linearities,\n",
"with intermittent \"pooling\" layers that apply downsampling --\n",
"similar to the 1989\n",
"[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n",
"architecture or the 2012\n",
"[AlexNet](https://doi.org/10.1145%2F3065386)\n",
"architecture."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qkGJCnMttTFK"
},
"source": [
"The final classification is performed by an MLP.\n",
"\n",
"In order to get vectors to pass into that MLP,\n",
"we first apply `torch.flatten`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WZPhw7ufAKZ7"
},
"outputs": [],
"source": [
"torch.flatten(torch.Tensor([[1, 2], [3, 4]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jCoCa3vCNM8j"
},
"source": [
"## Design considerations for CNNs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dDLEMnPINTj7"
},
"source": [
"Since the release of AlexNet,\n",
"there has been a feverish decade of engineering and innovation in CNNs --\n",
"[dilated convolutions](https://arxiv.org/abs/1511.07122),\n",
"[residual connections](https://arxiv.org/abs/1512.03385), and\n",
"[batch normalization](https://arxiv.org/abs/1502.03167)\n",
"came out in 2015 alone, and\n",
"[work continues](https://arxiv.org/abs/2201.03545) --\n",
"so we can only scratch the surface in this course and\n",
"[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n",
"\n",
"The progress of DNNs in general and CNNs in particular\n",
"has been mostly evolutionary,\n",
"with lots of good ideas that didn't work out\n",
"and weird hacks that stuck around because they did.\n",
"That can make it very hard to design a fresh architecture\n",
"from first principles that's anywhere near as effective as existing architectures.\n",
"You're better off tweaking and mutating an existing architecture\n",
"than trying to design one yourself.\n",
"\n",
"If you're not keeping close tabs on the field,\n",
"when you first start looking for an architecture to base your work off of\n",
"it's best to go to trusted aggregators, like\n",
"[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n",
"or `timm`, on GitHub, or\n",
"[Papers With Code](https://paperswithcode.com),\n",
"specifically the section for\n",
"[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n",
"You can also take a more bottom-up approach by checking\n",
"the leaderboards of the latest\n",
"[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n",
"\n",
"We'll briefly touch here on some of the main design considerations\n",
"with classic CNN architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nd0OeyouDNlS"
},
"source": [
"### Shapes and padding"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5w3p8QP6AnGQ"
},
"source": [
"In the `.forward` pass of the `CNN`,\n",
"we've included comments that indicate the expected shapes\n",
"of tensors after each line that changes the shape.\n",
"\n",
"Tracking and correctly handling shapes is one of the bugbears\n",
"of CNNs, especially architectures,\n",
"like LeNet/AlexNet, that include MLP components\n",
"that can only operate on fixed-shape tensors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vgbM30jstTFK"
},
"source": [
"[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n",
"if you're supporting a wide variety of convolutions.\n",
"\n",
"The easiest way to avoid shape bugs is to keep things simple:\n",
"choose your convolution parameters,\n",
"like `padding` and `stride`,\n",
"to keep the shape the same before and after\n",
"the convolution.\n",
"\n",
"That's what we do, by choosing `padding=1`\n",
"for `kernel_size=3` and `stride=1`.\n",
"With unit strides and odd-numbered kernel size,\n",
"the padding that keeps\n",
"the input the same size is `kernel_size // 2`.\n",
"\n",
"As shapes change, so does the amount of GPU memory taken up by the tensors.\n",
"Keeping sizes fixed within a block removes one axis of variation\n",
"in the demands on an important resource.\n",
"\n",
"After applying our pooling layer,\n",
"we can just increase the number of kernels by the right factor\n",
"to keep total tensor size,\n",
"and thus memory footprint, constant."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2BCkTZGSDSBG"
},
"source": [
"### Parameters, computation, and bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pZbgm7wztTFK"
},
"source": [
"If we review the `num`ber of `el`ements in each of the layers,\n",
"we see that one layer has far more entries than all the others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8nfjPVwztTFK"
},
"outputs": [],
"source": [
"[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DzIoCz1FtTFK"
},
"source": [
"The biggest layer is typically\n",
"the one in between the convolutional component\n",
"and the MLP component:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QYrlUprltTFK"
},
"outputs": [],
"source": [
"biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n",
"biggest_layer.shape, cnn.fc_input_dim"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HSHdvEGptTFL"
},
"source": [
"This layer dominates the cost of storing the network on disk.\n",
"That makes it a common target for\n",
"regularization techniques like DropOut\n",
"(as in our architecture)\n",
"and performance optimizations like\n",
"[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n",
"\n",
"Heuristically, we often associate more parameters with more computation.\n",
"But just because that layer has the most parameters\n",
"does not mean that most of the compute time is spent in that layer.\n",
"\n",
"Convolutions reuse the same parameters over and over,\n",
"so the total number of FLOPs done by the layer can be higher\n",
"than that done by layers with more parameters --\n",
"much higher."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YLisj1SptTFL"
},
"outputs": [],
"source": [
"# for the Linear layers, number of multiplications per input == nparams\n",
"cnn.fc1.weight.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yo2oINHRtTFL"
},
"outputs": [],
"source": [
"# for the Conv2D layers, it's more complicated\n",
"\n",
"def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n",
" num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n",
" input_height, input_width = input_size[1], input_size[2]\n",
"\n",
" multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n",
" num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n",
" mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n",
"\n",
" return mutliplications_per_kernel * num_kernels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LwCbZU9PtTFL"
},
"outputs": [],
"source": [
"approx_conv_multiplications(cnn.conv2.conv.weight.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sdco4m9UtTFL"
},
"outputs": [],
"source": [
"# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n",
"approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joVoBEtqtTFL"
},
"source": [
"Depending on your compute hardware and the problem characteristics,\n",
"either the MLP component or the convolutional component\n",
"could become the critical bottleneck.\n",
"\n",
"When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n",
"the MLP component is likely to be the bottleneck,\n",
"whereas when you are compute-constrained, like when running a model on a low-power edge device\n",
"or in an application with strict low-latency requirements,\n",
"the convolutional component is likely to be the bottleneck.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pGSyp67dtTFM"
},
"source": [
"## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AYTJs7snQfX0"
},
"source": [
"We have a model and we have data,\n",
"so we could just go ahead and start training in raw PyTorch,\n",
"[as we did in Lab 01](https://fsdl.me/lab01-colab).\n",
"\n",
"But as we saw in that lab,\n",
"there are good reasons to use a framework\n",
"to organize training and provide fixed interfaces and abstractions.\n",
"So we're going to use PyTorch Lightning, which is\n",
"[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hZYaJ4bdMcWc"
},
"source": [
"We provide a simple script that implements a command line interface\n",
"to training with PyTorch Lightning\n",
"using the models and datasets in this repository:\n",
"`training/run_experiment.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52kIYhPBPLNZ"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rkM_HpILSyC9"
},
"source": [
"The `pl.Trainer` arguments come first\n",
"and there\n",
"[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n",
"so if we want to see what's configurable for\n",
"our `Model` or our `LitModel`,\n",
"we want the last few dozen lines of the help message:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G0dBhgogO8_A"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NCBQekrPRt90"
},
"source": [
"The `run_experiment.py` file is also importable as a module,\n",
"so that you can inspect its contents\n",
"and play with its component functions in a notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CPumvYatPaiS"
},
"outputs": [],
"source": [
"import training.run_experiment\n",
"\n",
"\n",
"print(training.run_experiment.main.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YiZ3RwW2UzJm"
},
"source": [
"Let's run training!\n",
"\n",
"Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n",
"\n",
"This will take several minutes on commodity hardware,\n",
"so feel free to keep reading while it runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5RSJM5I2TSeG",
"scrolled": true
},
"outputs": [],
"source": [
"gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n",
"\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_ayQ4ByJOnnP"
},
"source": [
"The first thing you'll see are a few logger messages from Lightning,\n",
"then some info about the hardware you have available and are using."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VcMrZcecO1EF"
},
"source": [
"Then you'll see a summary of your model,\n",
"including module names, parameter counts,\n",
"and information about model disk size.\n",
"\n",
"`torchmetrics` show up here as well,\n",
"since they are also `nn.Module`s.\n",
"See [Lab 02a](https://fsdl.me/lab02a-colab)\n",
"for details.\n",
"We're tracking accuracy on training, validation, and test sets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "twGp9iWOUSfc"
},
"source": [
"You may also see a quick message in the terminal\n",
"referencing a \"validation sanity check\".\n",
"PyTorch Lightning runs a few batches of validation data\n",
"through the model before the first training epoch.\n",
"Otherwise, validation would first run at the end of the first epoch,\n",
"which is sometimes hours into training.\n",
"If the validation code has a bug,\n",
"this check makes the run crash quickly at the start\n",
"instead of hours later.\n",
"\n",
"If you want to turn off the check,\n",
"use `--num_sanity_val_steps=0`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jnKN3_MiRpE4"
},
"source": [
"Then, you'll see a bar indicating\n",
"progress through the training epoch,\n",
"alongside metrics like throughput and loss.\n",
"\n",
"When the first (and only) epoch ends,\n",
"the model is run on the validation set\n",
"and aggregate loss and accuracy are reported to the console."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R2eMZz_HR8vV"
},
"source": [
"At the end of training,\n",
"we call `Trainer.test`\n",
"to check performance on the test set.\n",
"\n",
"We typically see test accuracy around 75-80%."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ybpLiKBKSDXI"
},
"source": [
"During training, PyTorch Lightning saves _checkpoints_\n",
"(file extension `.ckpt`)\n",
"that can be used to restart training.\n",
"\n",
"The final line output by `run_experiment`\n",
"indicates where the model with the best performance\n",
"on the validation set has been saved.\n",
"\n",
"The checkpointing behavior is configured using a\n",
"[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n",
"The `run_experiment` script picks sensible defaults.\n",
"\n",
"These checkpoints contain the model weights.\n",
"We can use them to load the model in the notebook and play around with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3Rqh9ZQsY8g4"
},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest checkpoint's filename\n",
"# by hand, you can just copy and paste it\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n",
"filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"latest_ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QW_CxR3coV6"
},
"source": [
"To rebuild the model,\n",
"we need to consider some implementation details of the `run_experiment` script.\n",
"\n",
"We use the parsed command line arguments, the `args`, to build the data and model,\n",
"then use all three to build the `LightningModule`.\n",
"\n",
"Any `LightningModule` can be reinstantiated from a checkpoint\n",
"using the `load_from_checkpoint` method,\n",
"but we'll need to recreate and pass the `args`\n",
"in order to reload the model.\n",
"(We'll see how this can be automated later)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oVWEHcgvaSqZ"
},
"outputs": [],
"source": [
"import training.util\n",
"from argparse import Namespace\n",
"\n",
"\n",
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"CNN\",\n",
" \"data_class\": \"EMNIST\"})\n",
"\n",
"\n",
"_, cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MynyI_eUcixa"
},
"source": [
"With the model reloaded, we can run it on some sample data\n",
"and see how it's doing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L0HCxgVwcRAA"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(xs) - 1)\n",
"outs = reloaded_model(xs[idx:idx+1])\n",
"\n",
"print(\"output:\", emnist.mapping[torch.argmax(outs)])\n",
"wandb.Image(xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G6NtaHuVdfqt"
},
"source": [
"I generally see subjectively good performance --\n",
"without seeing the labels, I tend to agree with the model's output\n",
"more often than the accuracy would suggest,\n",
"since some classes, like c and C or o, O, and 0,\n",
"are essentially indistinguishable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ZzcDcxpVkki"
},
"source": [
"We can continue a promising training run from the checkpoint.\n",
"Run the cell below to train the model just trained above\n",
"for another epoch.\n",
"Note that the training loss starts out close to where it ended\n",
"in the previous run.\n",
"\n",
"Paired with cloud storage of checkpoints,\n",
"this makes it possible to use\n",
"[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n",
"that can be pre-empted by someone willing to pay more,\n",
"which terminates your job.\n",
"It's also helpful when using Google Colab for more serious projects --\n",
"your training runs are no longer bound by the maximum uptime of a Colab notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "skqdikNtVnaf"
},
"outputs": [],
"source": [
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"\n",
"\n",
"# and we can change the training hyperparameters, like batch size\n",
"%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n",
" --batch_size 64 --load_checkpoint {latest_ckpt}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HBdNt6Z2tTFM"
},
"source": [
"# Creating lines of text from handwritten characters: `EMNISTLines`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FevtQpeDtTFM"
},
"source": [
"We've got a training pipeline for our model and our data,\n",
"and we can use that to make the loss go down\n",
"and get better at the task.\n",
"But the problem we're solving is not obviously useful:\n",
"the model is just learning how to handle\n",
"centered, high-contrast, isolated characters.\n",
"\n",
"To make this work in a text recognition application,\n",
"we would need a component to first pull out characters like that from images.\n",
"That task is probably harder than the one we're currently learning.\n",
"Plus, splitting into two separate components is against the ethos of deep learning,\n",
"which operates \"end-to-end\".\n",
"\n",
"Let's kick the realism up one notch by building lines of text out of our characters:\n",
"_synthesizing_ data for our model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dH7i4JhWe7ch"
},
"source": [
"Synthetic data is generally useful for augmenting limited real data.\n",
"By construction we know the labels, since we created the data.\n",
"Often, we can track covariates,\n",
"like lighting features or subclass membership,\n",
"that aren't always available in our labels."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TrQ_44TIe39m"
},
"source": [
"To build fake handwriting,\n",
"we'll combine two things:\n",
"real handwritten letters and real text.\n",
"\n",
"We generate our fake text by drawing from the\n",
"[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n",
"provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n",
"\n",
"First, we download that corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gtSg7Y8Ydxpa"
},
"outputs": [],
"source": [
"from text_recognizer.data.sentence_generator import SentenceGenerator\n",
"\n",
"sentence_generator = SentenceGenerator()\n",
"\n",
"SentenceGenerator.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yal5eHk-aB4i"
},
"source": [
"We can generate short snippets of text from the corpus with the `SentenceGenerator`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eRg_C1TYzwKX"
},
"outputs": [],
"source": [
"print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsBuMICaXnM"
},
"source": [
"We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n",
"and glue them together into images containing the generated text."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YtsGfSu6dpZ9"
},
"outputs": [],
"source": [
"emnist_lines = text_recognizer.data.EMNISTLines() # configure\n",
"emnist_lines.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dik_SyEdb0st"
},
"source": [
"This can take several minutes when first run,\n",
"but afterwards data is persisted to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SofIYHOUtTFM"
},
"outputs": [],
"source": [
"emnist_lines.prepare_data() # download, save to disk\n",
"emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n",
"emnist_lines"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "axESuV1SeoM6"
},
"source": [
"Again, we're using the `LightningDataModule` interface\n",
"to organize our data prep,\n",
"so we can now fetch a batch and take a look at some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1J7f2I9ggBi-"
},
"outputs": [],
"source": [
"line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n",
"line_xs.shape, line_ys.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B0yHgbW2gHgP"
},
"outputs": [],
"source": [
"def read_line_labels(labels):\n",
" return [emnist_lines.mapping[label] for label in labels]\n",
"\n",
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"print(\"-\".join(read_line_labels(line_ys[idx])))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xirEmNPNtTFM"
},
"source": [
"The result looks\n",
"[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n",
"and is not yet anywhere near realistic, even for single lines --\n",
"letters don't overlap, the exact same handwritten letter is repeated\n",
"if the character appears more than once in the snippet --\n",
"but it's a start."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eRWbSzkotTFM"
},
"source": [
"# Applying CNNs to handwritten text: `LineCNNSimple`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pzwYBv82tTFM"
},
"source": [
"The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZqeImjd2lF7p"
},
"outputs": [],
"source": [
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"line_cnn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hi6g0acoxJO4"
},
"source": [
"The `nn.Module`s look much the same,\n",
"but the way they are used is different,\n",
"which we can see by examining the `.forward` method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg3UJhibxHfC"
},
"outputs": [],
"source": [
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LAW7EWVlxMhd"
},
"source": [
"The `CNN`, which operates on square images,\n",
"is applied to our wide image repeatedly,\n",
"slid over by the `W`indow `S`ize each time.\n",
"We effectively convolve the network with the input image.\n",
"\n",
"Like our synthetic data, it is crude\n",
"but it's enough to get started."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FU4J13yLisiC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = line_cnn(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OxHI4Gzndbxg"
},
"source": [
"> You may notice that this randomly-initialized\n",
"network tends to predict some characters far more often than others,\n",
"rather than predicting all characters with equal likelihood.\n",
"This is a commonly-observed phenomenon in deep networks.\n",
"It is connected to issues with\n",
"[model calibration](https://arxiv.org/abs/1706.04599)\n",
"and Bayesian uses of DNNs\n",
"(see e.g. Figure 7 of\n",
"[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NSonI9KcfJrB"
},
"source": [
"Let's launch a training run with the default parameters.\n",
"\n",
"This cell should run in just a few minutes on typical hardware."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rsbJdeRiwSVA"
},
"outputs": [],
"source": [
"%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n",
" --batch_size 32 --gpus {gpus} --max_epochs 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y9e5nTplfoXG"
},
"source": [
"You should see a test accuracy in the 65-70% range.\n",
"\n",
"That seems pretty good,\n",
"especially for a simple model trained in a minute.\n",
"\n",
"Let's reload the model and run it on some examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0NuXazAvw9NA"
},
"outputs": [],
"source": [
"# if you change around model/data args in the command above, add them here\n",
"# tip: define the arguments as variables, like we've done for gpus\n",
"# and then add those variables to this dict so you don't need to\n",
"# remember to update/copy+paste\n",
"\n",
"args = Namespace(**{\n",
" \"model_class\": \"LineCNNSimple\",\n",
" \"data_class\": \"EMNISTLines\"})\n",
"\n",
"\n",
"_, line_cnn = training.util.setup_data_and_model_from_args(args)\n",
"\n",
"latest_ckpt, = ! {list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n",
"print(latest_ckpt)\n",
"\n",
"reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n",
" latest_ckpt, args=args, model=line_cnn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "J8ziVROkxkGC"
},
"outputs": [],
"source": [
"idx = random.randint(0, len(line_xs) - 1)\n",
"\n",
"outs, = reloaded_lines_model(line_xs[idx:idx+1])\n",
"preds = torch.argmax(outs, 0)\n",
"\n",
"print(\"-\".join(read_line_labels(preds)))\n",
"wandb.Image(line_xs[idx]).image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N9bQCHtYgA0S"
},
"source": [
"In general,\n",
"we see predictions that have very low subjective quality:\n",
"it seems like most of the letters are wrong\n",
"and the model often prefers to predict the most common letters\n",
"in the dataset, like `e`.\n",
"\n",
"Notice, however, that many of the\n",
"characters in a given line are padding characters, `<P>`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 03: Transformers and Paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- The fundamental reasons why the Transformer is such\n",
"a powerful and popular architecture\n",
"- Core intuitions for the behavior of Transformer architectures\n",
"- How to use a convolutional encoder and a Transformer decoder to recognize\n",
"entire paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 3\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XZN4bGgsgWc_"
},
"source": [
"# Why Transformers?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal in building a text recognizer is to take a two-dimensional image\n",
"and convert it into a one-dimensional sequence of characters\n",
"from some alphabet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional neural networks,\n",
"discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n",
"are great at encoding images,\n",
"taking them from their raw pixel values\n",
"to a more semantically meaningful numerical representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But how do we go from that to a sequence of letters?\n",
"And what's especially tricky:\n",
"the number of letters in an image is largely independent of its size.\n",
"A screenshot of this document has a much higher density of letters\n",
"than a close-up photograph of a piece of paper.\n",
"How do we get a _variable-length_ sequence of letters,\n",
"where the length need have nothing to do with the size of the input tensor?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n",
"they were\n",
"[originally introduced](https://arxiv.org/abs/1706.03762)\n",
"for transforming one sequence into another,\n",
"as in machine translation.\n",
"This makes them a natural fit for processing language.\n",
"\n",
"But they have also found success in other domains --\n",
"at the time of this writing, large transformers\n",
"dominate the\n",
"[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n",
"that has become a de facto standard for comparing models\n",
"and are finding\n",
"[application in reinforcement learning](https://arxiv.org/abs/2106.01345)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So we will use a Transformer as a key component of our final architecture:\n",
"we will encode our input images with a CNN\n",
"and then read them out into a text sequence with a Transformer.\n",
"\n",
"Before trying out this new model,\n",
"let's first get an understanding of why the Transformer architecture\n",
"has become so popular by walking through its history\n",
"and then get some intuition for how it works\n",
"by looking at some\n",
"[recent work](https://transformer-circuits.pub/)\n",
"on explaining the behavior of both toy models and state-of-the-art language models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kmKqjbvd-Mj3"
},
"source": [
"## Why not convolutions?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SRqkUMdM-OxU"
},
"source": [
"In the ancient beforetimes (i.e. 2016),\n",
"the best models for natural language processing were all\n",
"_recurrent_ neural networks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convolutional networks were also occasionally used,\n",
"but they suffered from a serious issue:\n",
"their architectural biases don't fit text.\n",
"\n",
"First, _translation equivariance_ no longer holds.\n",
"The beginning of a piece of text is often quite different from the middle,\n",
"so the absolute position matters.\n",
"\n",
"Second, _locality_ is not as important in language.\n",
"The name of a character that hasn't appeared in thousands of pages\n",
"can become salient when someone asks, \"Whatever happened to\n",
"[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n",
"\n",
"Consider interpreting a piece of text like the Python code below:\n",
"```python\n",
"def do(arg1, arg2, arg3):\n",
" a = arg1 + arg2\n",
" b = arg3[:3]\n",
" c = a * b\n",
" return c\n",
"\n",
"print(do(1, 1, \"ayy lmao\"))\n",
"```\n",
"\n",
"After a `(` we expect a `)`,\n",
"but possibly very long afterwards,\n",
"[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n",
"and similarly we expect a `]` at some point after a `[`.\n",
"\n",
"For translation variance, consider\n",
"that we interpret `*` not by\n",
"comparing it to its neighbors\n",
"but by looking at `a` and `b`.\n",
"We mix knowledge learned through experience\n",
"with new facts learned while reading --\n",
"also known as _in-context learning_.\n",
"\n",
"In a longer text,\n",
"[e.g. the one you are reading now](./lab03_transformers.ipynb),\n",
"the translation variance of text is clearer.\n",
"Every lab notebook begins with the same header,\n",
"setting up the environment,\n",
"but that header never appears elsewhere in the notebook.\n",
"Later positions need to be processed in terms of the previous entries.\n",
"\n",
"Unlike an image, we cannot simply rotate or translate our \"camera\"\n",
"and get a new valid text.\n",
"[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n",
"that can be read without regard to position."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The field of formal language theory,\n",
"which has deep mutual influence with computer science,\n",
"gives one way of explaining the issues with convolutional networks:\n",
"they can only understand languages with _finite contexts_,\n",
"where all the information can be found within a finite window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The immediate solution, drawing from the connections to computer science, is\n",
"[recursion](https://www.google.com/search?q=recursion).\n",
"A network whose output on the final entry of the sequence is a recursive function\n",
"of all the previous entries can build up knowledge\n",
"as it reads the sequence and treat early entries quite differently than it does late ones."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aa6cbTlImkEh"
},
"source": [
"In pseudo-code, such a _recurrent neural network_ module might look like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lKtBoPnglPrW"
},
"source": [
"```python\n",
"def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n",
" next_inputs = input_module(xs[-1])\n",
" next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n",
" return output_module(next_inputs, next_hiddens)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbJPSMnEm516"
},
"source": [
"If you've had formal computer science training,\n",
"then you may be familiar with the power of recursion,\n",
"e.g. the\n",
"[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n",
"that gave its name to the now much better-known\n",
"[startup incubator](https://www.ycombinator.com/).\n",
"\n",
"The particular form of recursion used by\n",
"recurrent neural networks implements a\n",
"[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n",
"\n",
"> If you know a lot of computer science,\n",
"you might be concerned by this connection.\n",
"What about other\n",
"[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n",
"Where are the neural network architectures for differentiable\n",
"[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n",
"Check out Graph Neural Networks,\n",
"[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "63mMTbEBpVuE"
},
"source": [
"Recurrent networks are able to achieve\n",
"[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n",
"\n",
"There are many popular recurrent architectures,\n",
"from the beefy and classic\n",
"[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n",
"to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n",
"([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n",
"all of which have roughly similar capabilities but\n",
"[some of which are easier to train](https://arxiv.org/abs/1611.09913)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PwQHVTIslOku"
},
"source": [
"In the same sense that MLPs can model \"any\" feedforward function,\n",
"in principle even basic RNNs\n",
"[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n",
"\n",
"In particular they can model any\n",
"[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n",
"which is a formal way of saying that they can in principle\n",
"do anything a computer is capable of doing.\n",
"\n",
"The question is then..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3J8EoGN3pu7P"
},
"source": [
"## Why aren't we all using RNNs?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDwNWaevpt_3"
},
"source": [
"The guarantees that MLPs can model any function\n",
"or that RNNs can model Turing machines\n",
"provide decent intuition but are not directly practically useful.\n",
"Among other reasons, they don't guarantee learnability --\n",
"that starting from random parameters we can find the parameters\n",
"that implement a given function.\n",
"The\n",
"[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n",
"than would seem from basic theoretical and empirical analysis.\n",
"\n",
"One way of understanding capacity to model language is\n",
"[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n",
"In this model of formal languages,\n",
"Turing machines sit at the top\n",
"([practically speaking](https://arxiv.org/abs/math/0209332)).\n",
"\n",
"With better mathematical models,\n",
"RNNs and LSTMs can be shown to be\n",
"[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n",
"with RNNs looking more like\n",
"[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n",
"and LSTMs coming in\n",
"[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n",
"\n",
"More controversially:\n",
"the Chomsky hierarchy is great for understanding syntax and grammar,\n",
"which makes it great for building parsers\n",
"and working with formal languages,\n",
"but the goal in _natural_ language processing is to understand _natural_ language.\n",
"Most humans' natural language is far from strictly grammatical,\n",
"but that doesn't mean it is nonsense.\n",
"\n",
"And to really \"understand\" language means\n",
"to understand its semantic content, which is fuzzy.\n",
"The most important thing for handling the fuzzy semantic content\n",
"of language is not whether you can recall\n",
"[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n",
"but whether you can model probabilistic relationships between concepts\n",
"in addition to grammar and syntax."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These both leave theoretical room for improvement over current recurrent\n",
"language and sequence models.\n",
"\n",
"But the real cause of the rise of Transformers is that..."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dsu1ebvAp-3Z"
},
"source": [
"## Transformers are designed to train fast at scale on contemporary hardware."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4abU5adsPGs"
},
"source": [
"The Transformer architecture has several important features,\n",
"discussed below,\n",
"but one of the most important reasons why it is successful\n",
"is that it can be more easily trained at scale.\n",
"\n",
"This scalability is the focus of the discussion in the paper\n",
"that introduced the architecture,\n",
"[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
"and\n",
"[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n",
"\n",
"The recursion in RNNs is inherently sequential:\n",
"the dependence on the outputs from earlier in the sequence\n",
"means computations within an example cannot be parallelized.\n",
"\n",
"So RNNs must batch across examples to scale,\n",
"but as sequence length grows this hits memory bandwidth limits.\n",
"Serving up large batches quickly with good randomness guarantees\n",
"is also hard to optimize,\n",
"especially in distributed settings.\n",
"\n",
"The Transformer architecture,\n",
"on the other hand,\n",
"can be readily parallelized within a single example sequence,\n",
"in addition to parallelization across batches.\n",
"This can lead to massive performance gains for a fixed scale,\n",
"which means larger, higher capacity models\n",
"can be trained on larger datasets."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_Mzk2haFC_G1"
},
"source": [
"How does the architecture achieve this parallelizability?\n",
"\n",
"Let's start with the architecture diagram:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u59eu4snLQfp"
},
"outputs": [],
"source": [
"from IPython import display\n",
"\n",
"base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n",
"\n",
"display.Image(url=base_url + \"/aiayn-figure-1.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ez-XEQ7M0UlR"
},
"source": [
"> To head off a bit of confusion\n",
" in case you've worked with Transformer architectures before:\n",
" the original \"Transformer\" is an encoder/decoder architecture.\n",
" Many LLMs, like GPT models, are decoder only,\n",
" because this has turned out to scale well,\n",
" and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n",
" it's all text anyways.\n",
" We, however, will be using them across modalities,\n",
" so we need an explicit encoder,\n",
" as above. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ok4ksBi4vp89"
},
"source": [
"First focusing on the encoder (left):\n",
"the encoding at a given position is a function of all the inputs, not just the previous ones.\n",
"But it is not a function of the previous _encodings_:\n",
"we produce the encodings \"all at once\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RPN7C-_OqzHP"
},
"source": [
"The decoder (right) does use previous \"outputs\" as its inputs,\n",
"but those outputs are not the vectors of layer activations\n",
"(aka embeddings)\n",
"that are produced by the network.\n",
"They are instead the processed outputs,\n",
"after a `softmax` and an `argmax`.\n",
"\n",
"We could obtain these outputs by processing the embeddings,\n",
"much like in a recurrent architecture.\n",
"In fact, that is one way that Transformers are run.\n",
"It's what happens in the `.forward` method\n",
"of the model we'll be training for character recognition:\n",
"`ResnetTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L5_2WMmtDnJn"
},
"source": [
"Let's look at that forward method\n",
"and connect it to the diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FR5pk4kEyCGg"
},
"outputs": [],
"source": [
"from text_recognizer.models import ResnetTransformer\n",
"\n",
"\n",
"ResnetTransformer.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-J5UFDoPzPbq"
},
"source": [
"`.encode` happens first -- that's the left side of the diagram.\n",
"\n",
"The encoder can in principle be anything\n",
"that produces a sequence of fixed-length vectors,\n",
"but here it's\n",
"[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n",
"\n",
"Then we start iterating over the sequence\n",
"in the `for` loop.\n",
"\n",
"Focus on the first few lines of code.\n",
"We apply `.decode` (right side of the diagram)\n",
"to the outputs so far.\n",
"\n",
"Once we have a new `output`, we apply `.argmax`\n",
"to turn the logits into a concrete prediction of\n",
"a particular token.\n",
"\n",
"This is added as the last output token\n",
"and then the loop happens again."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTcy8-rV1dHr"
},
"source": [
"Run this way, our model looks very much like a recurrent architecture:\n",
"we call the model on its own outputs\n",
"to generate the next value.\n",
"These types of models are also referred to as\n",
"[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n",
"because we predict (as we do in _regression_)\n",
"the next value based on our own (_auto_) output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But Transformers are designed to be _trained_ more scalably than RNNs,\n",
"not necessarily to _run inference_ more scalably,\n",
"and it's actually not the case that our model's `.forward` is called during training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eCxMSAWmEKBt"
},
"source": [
"Let's look at what happens during training\n",
"by checking the `training_step`\n",
"of the `LightningModule`\n",
"we use to train our Transformer models,\n",
"the `TransformerLitModel`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o7q8N7P2w4H"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models import TransformerLitModel\n",
"\n",
"TransformerLitModel.training_step??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1VgNNOjvzC4y"
},
"source": [
"Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tz-6NGPR4dUr"
},
"source": [
"Let's look at `.teacher_forward`,\n",
"and in particular its type signature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ILc2oWET4i2Z"
},
"outputs": [],
"source": [
"TransformerLitModel.teacher_forward??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lf32lpgrDb__"
},
"source": [
"This is known as \"teacher forcing\".\n",
"The \"teacher\" signal is \"forcing\"\n",
"the model to behave as though\n",
"it got the answer right.\n",
"\n",
"[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n",
"It's more effective here\n",
"because the right teaching signal\n",
"for our network is the target data,\n",
"which we have access to during training,\n",
"whereas in an RNN the best teaching signal\n",
"would be the target embedding vector,\n",
"which we do not know.\n",
"\n",
"During inference, when we don't have access to the ground truth,\n",
"we revert to the autoregressive `.forward` method."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This \"trick\" allows Transformer architectures to readily scale\n",
"up models to the parameter counts\n",
"[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BAjqpJm9uUuU"
},
"source": [
"## Is there more to Transformers than just a training trick?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kWCYXeHv7Qc9"
},
"source": [
"[Very](https://arxiv.org/abs/2005.14165),\n",
"[very](https://arxiv.org/abs/1909.08053),\n",
"[very](https://arxiv.org/abs/2205.01068)\n",
"large Transformer models have powered the most recent wave of exciting results in ML, like\n",
"[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n",
"\n",
"They are also the first machine learning models to have come anywhere close to\n",
"deserving the term _artificial intelligence_ --\n",
"a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is surprising because the models and their training procedure are\n",
"(relatively speaking)\n",
"pretty _simple_,\n",
"even if it doesn't feel that way on first pass."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The basic Transformer architecture is just a bunch of\n",
"dense matrix multiplications and non-linearities --\n",
"it's perhaps simpler than a convolutional architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And advances since the introduction of Transformers in 2017\n",
"have not in the main been made by\n",
"creating more sophisticated model architectures\n",
"but by increasing the scale of the base architecture,\n",
"or if anything making it simpler, as in\n",
"[GPT-type models](https://arxiv.org/abs/2005.14165),\n",
"which drop the encoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V1HQS9ey8GMc"
},
"source": [
"These models are also trained on very simple tasks:\n",
"most LLMs are just trying to predict the next element in the sequence,\n",
"given the previous elements --\n",
"a task simple enough that Claude Shannon,\n",
"father of information theory, was\n",
"[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n",
"\n",
"These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n",
"e.g. by scraping the web."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They are also trained in a simple fashion:\n",
"first-order stochastic optimizers, like SGD or an\n",
"[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n",
"intended for the most basic of optimization problems,\n",
"that scale more readily than the second-order optimizers\n",
"that dominate other areas of optimization."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kz9HPDoy7OAl"
},
"source": [
"This is\n",
"[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n",
"of work in ML:\n",
"simple, even seemingly wasteful,\n",
"architectures that scale well and are robust\n",
"to implementation details\n",
"eventually outstrip more clever but\n",
"also more finicky approaches that are harder to scale.\n",
"This lesson has led some to declare that\n",
"[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n",
"in machine learning, and perhaps even in artificial intelligence."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdN9o2Y771YZ"
},
"source": [
"> That is not to say that because the algorithms are relatively simple,\n",
" training a model at this scale is _easy_ --\n",
" [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n",
" [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n",
" [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n",
" But choosing the simplest algorithm at every step makes solving the scaling problem feasible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baVGf6gKFOvs"
},
"source": [
"The importance of scale is the key lesson from the Transformer architecture,\n",
"far more than any theoretical considerations\n",
"or any of the implementation details.\n",
"\n",
"That said, these large Transformer models are capable of\n",
"impressive behaviors and understanding how they achieve them\n",
"is of intellectual interest.\n",
"Furthermore, like any architecture,\n",
"there are common failure modes,\n",
"of the models and of the modelers who use them,\n",
"that need to be taken into account."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"Below, we'll cover two key intuitions about Transformers:\n",
"Transformers are _residual_, like ResNets,\n",
"and they compose _low rank_ sequence transformations.\n",
"Together, this means they act somewhat like a computer,\n",
"reading from and writing to a \"tape\" or memory\n",
"with a sequence of simple instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1t2Cfq9Fq67Q"
},
"source": [
"We'll also cover a surprising implementation detail:\n",
"despite being commonly used for sequence modeling,\n",
"by default the architecture is _position insensitive_."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uni0VTCr9lev"
},
"source": [
"### Intuition #1: Transformers are highly residual."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0MoBt-JLJz-d"
},
"source": [
"> The discussion of these intuitions summarizes the discussion in\n",
"[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n",
"from\n",
"[Anthropic](https://www.anthropic.com/),\n",
"an AI safety and research company.\n",
"The figures below are from that blog post.\n",
"It is the spiritual successor to the\n",
"[Circuits Thread](https://distill.pub/2020/circuits/)\n",
"covered in\n",
"[Lab 02b](https://lab02b-colab).\n",
"If you want to truly understand Transformers,\n",
"we highly recommend you check it out,\n",
"including the\n",
"[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UUbNVvM5Ferm"
},
"source": [
"It's easy to see that ResNets are residual --\n",
"it's in the name, after all.\n",
"\n",
"But Transformers are,\n",
"in some sense,\n",
"even more closely tied to residual computation\n",
"than are ResNets:\n",
"ResNets and related architectures include downsampling,\n",
"so there is not a direct path from inputs to outputs.\n",
"\n",
"In Transformers, the exact same shape is maintained\n",
"from the moment tokens are embedded,\n",
"through dozens or hundreds of intermediate layers,\n",
"and until they are \"unembedded\" into class logits.\n",
"The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n",
"\n",
"The residual stream is easy to see with a change of perspective.\n",
"Instead of the usual architecture diagram above,\n",
"which emphasizes the layers acting on the tensors,\n",
"consider this alternative view,\n",
"which emphasizes the tensors as they pass through the layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HRMlVguKKW6y"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-residual-view.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a9K3N7ilVkB3"
},
"source": [
"For definitions of variables and terms, see the\n",
"[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arvciE-kKd_L"
},
"source": [
"Note that this is a _decoder-only_ Transformer architecture --\n",
"so it should be compared with the right-hand side of the original architecture diagram above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wvrRMd_RKp_G"
},
"source": [
"Notice that outputs of the attention blocks \n",
"and of the MLP layers are\n",
"added to their inputs, as in a ResNet.\n",
"These operations are represented as \"Add & Norm\" layers in the classical diagram;\n",
"normalization is ignored here for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8n_iT-FFAbK"
},
"source": [
"This total commitment to residual operations\n",
"means the size of the embeddings\n",
"(referred to as the \"model dimension\" or the \"embedding dimension\",\n",
"here and below `d_model`)\n",
"stays the same throughout the entire network.\n",
"\n",
"That means, for example,\n",
"that the output of each layer can be used as input to the \"unembedding\" layer\n",
"that produces logits.\n",
"We can read out the computations of intermediate layers\n",
"just by passing them through the unembedding layer\n",
"and examining the logit tensor.\n",
"See\n",
"[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n",
"for detailed experiments and interactive notebooks.\n",
"\n",
"In short, we observe a sort of \"progressive refinement\"\n",
"of the next-token prediction\n",
"as the embeddings proceed, depthwise, through the network."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ovh_3YgY9z2h"
},
"source": [
"### Intuition #2: Transformer heads learn low rank transformations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpNmozlnOdPC"
},
"source": [
"In the original paper and in\n",
"most presentations of Transformers,\n",
"the attention layer is written like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PA7me8gNP5LE"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In pseudo-typed PyTorch (based loosely on\n",
"[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n",
"that looks like:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Oeict_6wGJgD"
},
"source": [
"```python\n",
"def classic_attention(\n",
" Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" K: torch.Tensor[\"d_sequence\", \"d_model\"],\n",
" V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" return torch.softmax(Q @ K.T, dim=-1) @ V\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8pewU90DSuOR"
},
"source": [
"This is effectively exactly\n",
"how it is written\n",
"in PyTorch,\n",
"apart from implementation details\n",
"(look for `bmm` for the matrix multiplications and a `softmax` call):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WrgTpKFvOhwc"
},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"F._scaled_dot_product_attention??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ebDXZ0tlSe7g"
},
"source": [
"But the best way to write an operation so that a computer can execute it quickly\n",
"is not necessarily the best way to write it so that a human can understand it --\n",
"otherwise we'd all be coding in assembly.\n",
"\n",
"And this is a strange way to write it --\n",
"you'll notice that what we normally think of\n",
"as the \"inputs\" to the layer are not shown.\n",
"\n",
"We can instead write out the attention layer\n",
"as a function of the inputs $x$.\n",
"We write it for a single \"attention head\".\n",
"Each attention layer includes a number of heads\n",
"that read and write from the residual stream\n",
"simultaneously and independently.\n",
"We also add the output layer weights $W_O$\n",
"and we get:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LuFNR67tQpsf"
},
"outputs": [],
"source": [
"display.Latex(r\"$\\text{softmax}(\\underbrace{x^TW_Q^T}_Q \\underbrace{W_Kx}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVnBjjfOLwxP"
},
"source": [
"or, in pseudo-typed PyTorch:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LmpOm-HfGaNz"
},
"source": [
"```python\n",
"def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n",
" query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n",
" key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n",
" key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n",
" # maps queries of residual stream to keys from residual stream, independent of position\n",
"\n",
" value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n",
" output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n",
" value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n",
" # transformation applied to each token, regardless of position\n",
"\n",
" attention_logits: torch.Tensor[\"d_sequence\", \"d_sequence\"] = x @ key_query_circuit @ x.T\n",
" attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n",
" # maps positions to positions, often very sparse\n",
"\n",
" value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n",
"\n",
" return attention_map @ value_output # transformed tokens filtered by attention map\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dC0eqxZ6UAGT"
},
"source": [
"Consider the `key_query_circuit`\n",
"and `value_output_circuit`\n",
"matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$\n",
"\n",
"The key/query dimension, `d_head`,\n",
"is small relative to the model's dimension, `d_model`,\n",
"so $W_{QK}$ and $W_{OV}$ are very low rank,\n",
"[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n",
"that they factorize into two matrices,\n",
"one with a smaller number of rows\n",
"and another with a smaller number of columns.\n",
"That number is called the _rank_.\n",
"\n",
"When computing, these matrices are better represented via their components,\n",
"rather than computed directly,\n",
"which leads to the normal implementation of attention.\n",
"\n",
"In a large language model,\n",
"the ratio of residual stream dimension, `d_model`, to\n",
"the dimension of a single head, `d_head`, is huge, often 100:1.\n",
"That means each query, key, and value computed at a position\n",
"is a fairly simple, low-dimensional feature of the residual stream at that position.\n",
"\n",
"For visual intuition,\n",
"we compare what a matrix whose rank is 1/100th of full rank looks like,\n",
"relative to a full rank matrix of the same size:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LUbojJMiW2C"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"\n",
"low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n",
"full_rank = torch.randn(100, 100)\n",
"plt.figure(); plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n",
"plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lqBst92-OVka"
},
"source": [
"The pattern in the first matrix is very simple,\n",
"relative to the pattern in the second matrix."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SkCGrs9EiVh4"
},
"source": [
"Another feature of low rank transformations is\n",
"that they have a large nullspace or kernel --\n",
"these are directions we can move the input without changing the output.\n",
"\n",
"That means that many changes to the residual stream won't affect the behavior of this head at all."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UVz2dQgzhD4p"
},
"source": [
"### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVlzwR03m8mC"
},
"source": [
"The combination of residuality\n",
"(changes are added to the current input)\n",
"and low rank\n",
"(only a small subspace is changed by each head)\n",
"drastically changes the intuition about Transformers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qqjZI2jKe6HH"
},
"source": [
"Rather than being an \"embedding of a token in its context\",\n",
"the residual stream becomes something more like a memory or a scratchpad:\n",
"one layer reads a small bit of information from the stream\n",
"and writes a small bit of information back to it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5YIBkxlqepjc"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/transformer-layer-residual.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RtsKhkLfk00l"
},
"source": [
"The residual stream works like a memory because it is roomy enough\n",
"that these actions need not interfere:\n",
"the subspaces targeted by reads and writes are small relative to the ambient space,\n",
"so they can coexist without overwriting one another.\n",
"\n",
"Additionally, the dimension of each head is still in the 100s in large models,\n",
"and\n",
"[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n",
"in them, so the number of effective degrees of freedom is\n",
"actually larger than the dimension.\n",
"This phenomenon allows high-dimensional tensors to serve as\n",
"[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n",
"There are\n",
"[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n",
"\n",
"Together, this means an early layer can write information to the stream\n",
"that can be used by later layers -- by many of them at once, possibly much later.\n",
"Later layers can learn to edit this information,\n",
"e.g. deleting it,\n",
"if doing so reduces the loss,\n",
"but by default the information is preserved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EragIygzJg86"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-stream-read-write.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oKIaUZjwkpW7"
},
"source": [
"Lastly, the softmax in the attention has a sparsifying effect,\n",
"and so many attention heads are reading from \n",
"just one token and writing to just one other token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dN6VcJqIMKnB"
},
"outputs": [],
"source": [
"display.Image(url=base_url + \"/residual-token-to-token.png\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Repeatedly reading information from an external memory\n",
"and using it to decide which operation to perform\n",
"and where to write the results\n",
"is at the core of the\n",
"[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n",
"For a concrete example, the\n",
"[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n",
"includes a dissection of a form of \"pointer arithmetic\"\n",
"that appears in some models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0kLFh7Mvnolr"
},
"source": [
"This point of view seems\n",
"very promising for explaining numerous\n",
"otherwise perhaps counterintuitive features of Transformer models.\n",
"\n",
"- This framework predicts that Transformers will readily copy-and-paste information,\n",
"which might explain phenomena like\n",
"[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n",
"\n",
"- It also readily explains\n",
"[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n",
"an important component of why Transformers perform well on medium-length texts\n",
"and in few-shot learning.\n",
"\n",
"- Transformers also perform better on reasoning tasks when the text\n",
"[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n",
"is added to their input prompt.\n",
"This is partly because that prompt is associated,\n",
"in the dataset, with clearer reasoning,\n",
"and since the models are trained to predict which tokens tend to appear\n",
"after an input, they tend to produce better reasoning with that prompt --\n",
"an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n",
"that act to store intermediate information,\n",
"making for a richer residual stream\n",
"for reading and writing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RyLRzgG-93yB"
},
"source": [
"### Implementation detail: Transformers are position-insensitive by default."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oR6PnrlA_hJ2"
},
"source": [
"In the attention calculation\n",
"each token can query each other token,\n",
"with no regard for order.\n",
"Furthermore, the construction of queries, keys, and values\n",
"is based on the content of the embedding vector,\n",
"which does not automatically include its position.\n",
"\"dog bites man\" and \"man bites dog\" are indistinguishable to the model, as in\n",
"[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n",
"\n",
"For most sequences,\n",
"this is unacceptable:\n",
"absolute and relative position matter\n",
"and we cannot use the future to predict the past.\n",
"\n",
"We need to add two pieces to get a Transformer architecture that's usable for next-token prediction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EWHxGJz2-6ZK"
},
"source": [
"First, the simpler piece:\n",
"\"causal\" attention,\n",
"so-named because it ensures that values earlier in the sequence\n",
"are not influenced by later values, which would\n",
"[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0c42xi6URYB4"
},
"source": [
"The most common solution is straightforward:\n",
"we calculate attention between all tokens,\n",
"then throw out non-causal values by \"masking\" them\n",
"(this is before applying the softmax,\n",
"so masking means adding $-\\infty$).\n",
"\n",
"This feels wasteful --\n",
"why are we calculating values we don't need?\n",
"Trying to be smarter would be harder,\n",
"and might rely on operations that aren't as optimized as\n",
"matrix multiplication and addition.\n",
"Furthermore, it's \"only\" twice as many operations,\n",
"so it doesn't even show up in $O$-notation.\n",
"\n",
"A sample attention mask generated by our code base is shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NXaWe6pT-9jV"
},
"outputs": [],
"source": [
"from text_recognizer.models import transformer_util\n",
"\n",
"\n",
"attention_mask = transformer_util.generate_square_subsequent_mask(100)\n",
"\n",
"ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n",
"plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n",
"print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This solves our causality problem,\n",
"but we still don't have positional information."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZamUE4WIoGS2"
},
"source": [
"The standard technique\n",
"is to add alternating sines and cosines\n",
"of increasing frequency to the embeddings\n",
"(there are\n",
"[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n",
"most notably\n",
"[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n",
"Each position in the sequence is then uniquely identifiable\n",
"from the pattern of these values.\n",
"\n",
"> Furthermore, for the same reason that\n",
" [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n",
" translations, e.g. relative positions, are fairly easy to express as linear transformations\n",
" of sines and cosines."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IDG2uOsaELU0"
},
"source": [
"We superimpose this positional information on our embeddings.\n",
"Note that because the model is residual,\n",
"this position information will be preserved by default\n",
"as it passes through the network,\n",
"so it doesn't need to be repeatedly added."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what this positional encoding looks like in our codebase:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Zk62Q-a-1Ax"
},
"outputs": [],
"source": [
"PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n",
"\n",
"pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n",
"\n",
"ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n",
"print(pe[:4, :8])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ep2ClIWvqDms"
},
"source": [
"When we add the positional information to our embeddings,\n",
"both the embedding information and the positional information\n",
"are approximately preserved,\n",
"as can be visually assessed below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PJuFjoCzC0Y4"
},
"outputs": [],
"source": [
"fake_embeddings = torch.randn_like(pe) * 0.5\n",
"\n",
"ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n",
"\n",
"fake_embeddings_with_pe = fake_embeddings + pe\n",
"\n",
"plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n",
"plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHIzBxDkEmH8"
},
"source": [
"A [similar technique](https://arxiv.org/abs/2103.06450)\n",
"is used to also incorporate positional information into the image embeddings,\n",
"which are flattened before being fed to the decoder."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HC1N85wl8dvn"
},
"source": [
"### Learn more about Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lJwYxkjTk15t"
},
"source": [
"We're only able to give a flavor and an intuition for Transformers here.\n",
"\n",
"To improve your grasp on the nuts and bolts, check out the\n",
"[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n",
"which is surprisingly approachable,\n",
"as far as ML research papers go.\n",
"The\n",
"[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n",
"adds code and commentary to the original paper,\n",
"which makes it even more digestible.\n",
"For something even friendlier, check out the\n",
"[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n",
"by Jay Alammar, which has an accompanying\n",
"[video](https://youtu.be/-QH8fRhqFHM).\n",
"\n",
"Anthropic's work on\n",
"[Transformer Circuits](https://transformer-circuits.pub/),\n",
"summarized above, has some of the best material\n",
"for building theoretical understanding\n",
"and is still being updated with extensions and applications of the framework.\n",
"The\n",
"[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n",
"are a great aid for checking and building your understanding.\n",
"\n",
"But they are fairly math-heavy.\n",
"If you have more of a software engineering background, see\n",
"Transformer Circuits co-author Nelson Elhage's blog post\n",
"[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n",
"\n",
"For a gentler introduction to the intuition for Transformers,\n",
"check out Brandon Rohrer's\n",
"[Transformers From Scratch](https://e2eml.school/transformers.html)\n",
"tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qg7zntJES-aT"
},
"source": [
"An aside:\n",
"the matrix multiplications inside attention dominate\n",
"the big-$O$ runtime of Transformers.\n",
"So trying to make the attention mechanism more efficient, e.g. linear time,\n",
"has generated a lot of research\n",
"(review paper\n",
"[here](https://arxiv.org/abs/2009.06732)).\n",
"Despite drawing a lot of attention, so to speak,\n",
"at the time of writing in mid-2022, these methods\n",
"[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n",
"so it isn't likely to be worth the effort to spend time learning about them\n",
"unless you are a Transformer specialist."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vCjXysEJ8g9_"
},
"source": [
"# Using Transformers to read paragraphs of text"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsfKWnOvqjva"
},
"source": [
"Our simple convolutional model for text recognition from\n",
"[Lab 02b](https://fsdl.me/lab02b-colab)\n",
"could only handle cleanly-separated characters.\n",
"\n",
"It worked by sliding a LeNet-style CNN\n",
"over the image,\n",
"predicting a character for each step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "njLdzBqy-I90"
},
"outputs": [],
"source": [
"import text_recognizer.data\n",
"\n",
"\n",
"emnist_lines = text_recognizer.data.EMNISTLines()\n",
"line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n",
"\n",
"# for sliding, see the for loop over range(S)\n",
"line_cnn.forward??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K0N6yDBQq8ns"
},
"source": [
"But unfortunately for us, handwritten text\n",
"doesn't come in neatly-separated characters\n",
"of equal size, so we trained our model on synthetic data\n",
"designed to work with that model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hiqUVbj0sxLr"
},
"source": [
"Now that we have a better model,\n",
"we can work with better data:\n",
"paragraphs from the\n",
"[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oizsOAcKs-dD"
},
"source": [
"The cell below uses our `LightningDataModule`\n",
"to download and preprocess this data,\n",
"writing results to disk.\n",
"We can then spin up `DataLoader`s to give us batches.\n",
"\n",
"It can take several minutes to run the first time\n",
"on commodity machines,\n",
"with most time spent extracting the data.\n",
"On subsequent runs,\n",
"the time-consuming operations will not be repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uL9LHbjdsUbm"
},
"outputs": [],
"source": [
"iam_paragraphs = text_recognizer.data.IAMParagraphs()\n",
"\n",
"iam_paragraphs.prepare_data()\n",
"iam_paragraphs.setup()\n",
"xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n",
"\n",
"iam_paragraphs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBkFN9bbTm_S"
},
"source": [
"Now that we've got a batch,\n",
"let's take a look at some samples:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hqaps8yxtBhU"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"import numpy as np\n",
"import wandb\n",
"\n",
"\n",
"def show(y):\n",
" y = y.detach().cpu() # bring back from accelerator if it's being used\n",
" return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 04: Experiment Management"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How experiment management brings observability to ML model development\n",
"- Which features of experiment management we use in developing the Text Recognizer\n",
"- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 4\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lab contains a large number of embedded iframes\n",
"that benefit from having a wide window.\n",
"The cell below makes the notebook as wide as your browser window\n",
"if `full_width` is set to `True`.\n",
"Full width is the default behavior in Colab,\n",
"so this cell is intended to improve the viewing experience in other Jupyter environments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Why experiment management?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To understand why we need experiment management for ML development,\n",
"let's start by running an experiment.\n",
"\n",
"We'll train a new model on a new dataset,\n",
"using the training script `training/run_experiment.py`\n",
"introduced in [Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use a CNN encoder and Transformer decoder, as in\n",
"[Lab 03](https://fsdl.me/lab03-colab),\n",
"but with some changes so we can iterate faster.\n",
"We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n",
"[Lab 02b](https://fsdl.me/lab02b-colab),\n",
"and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n",
"from text_recognizer.data import IAMLines # processed version split into individual lines\n",
"from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n",
"\n",
"\n",
"print(IAM.__doc__)\n",
"\n",
"# uncomment a line below for details on either class\n",
"# IAMLines?? \n",
"# LineCNNTransformer??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below will train a model on 10% of the data for two epochs.\n",
"\n",
"It takes up to a few minutes to run on commodity hardware,\n",
"including data download and preprocessing.\n",
"As it's running, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%%time\n",
"import torch\n",
"\n",
"\n",
"gpus = int(torch.cuda.is_available()) \n",
"\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the model trains, we're calculating lots of metrics --\n",
"loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n",
"and reporting them to the terminal.\n",
"\n",
"This is achieved by the built-in `.log` method\n",
"([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n",
"of the `LightningModule`,\n",
"and it is a very straightforward way to get basic information about your experiment as it's running\n",
"without leaving the context where you're running it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Learning to read\n",
"[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n",
"is something of a rite of passage for MLEs, but\n",
"let's consider what we can't see here."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We're missing all metric values except the most recent --\n",
"we can see them as they stream in, but they're constantly overwritten.\n",
"We also can't associate them with timestamps, steps, or epochs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- We also don't see any system metrics.\n",
"We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n",
"without launching a separate process.\n",
"And even if we do, those values will also not be saved and timestamped,\n",
"so we can't correlate them with other things during training."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As we continue to run experiments, changing code and opening new terminals,\n",
"even the information we have or could figure out now will disappear.\n",
"Say you spot a weird error message during training,\n",
"but your session ends and the stdout is gone,\n",
"so you don't know exactly what it was.\n",
"Can you recreate the error?\n",
"Which git branch and commit were you on?\n",
"Did you have any uncommitted changes? Which arguments did you pass?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Also, model checkpoints containing the parameter values have been saved to disk.\n",
"Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n",
"As we run more and more experiments,\n",
"we'll want to slice and dice them to see if,\n",
"say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to save and log all of this information, and more, in order to make our model training\n",
"[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n",
"in short, so that we can understand, make decisions about, and debug our model training\n",
"by looking at logs and source code, without having to recreate it."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we had to write the code to log all of this information ourselves, we'd be in for a world of hurt:\n",
"1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n",
"2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n",
"3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n",
"4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Local Experiment Tracking with TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n",
"and it makes logging very easy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n",
"\n",
"It can also use a logger to push information elsewhere.\n",
"\n",
"By default, we use\n",
"[TensorBoard](https://www.tensorflow.org/tensorboard)\n",
"via the Lightning `TensorBoardLogger`,\n",
"which has been saving results to the local disk.\n",
"\n",
"Let's find them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we use a sequence of bash commands to get the latest experiment's directory\n",
"# by hand, you can just copy and paste it from the terminal\n",
"\n",
"list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n",
"filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n",
"sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n",
"take_first = \"head -n 1\" # the first n elements, n=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n",
"latest_log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!ls -lh {latest_log}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, we need to launch a TensorBoard server --\n",
"much like we need to launch a Jupyter server to use Jupyter notebooks.\n",
"\n",
"The cells below load an extension that lets you use TensorBoard inside of a notebook\n",
"the same way you'd use it from the command line, and then launch it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n",
"\n",
"port = 11717 # pick an open port on your machine\n",
"host = \"0.0.0.0\" # allow connections from the internet\n",
" # watch out! make sure you turn TensorBoard off\n",
"\n",
"%tensorboard --logdir {latest_log} --port {port} --host {host}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You should see some charts of metrics over time along with some charting controls.\n",
"\n",
"You can click around in this interface and explore it if you'd like,\n",
"but in the next section, we'll see that there are better tools for experiment management."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've run many experiments on this machine,\n",
"you can see all of their results by pointing TensorBoard\n",
"at the whole `lightning_logs` directory,\n",
"rather than just one experiment:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For large numbers of experiments, the management experience is not great --\n",
"for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n",
"\n",
"It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n",
"which are important workflows as applications mature and teams grow."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True` in the cell below and run it.\n",
"\n",
"If you run into any issues with the above cells failing to launch,\n",
"especially across iterations of this lab, run the cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorboard.manager\n",
"\n",
"# get the process IDs for all tensorboard instances\n",
"pids = [tb.pid for tb in tensorboard.manager.get_all()]\n",
"\n",
"done_with_tensorboard = False\n",
"\n",
"if done_with_tensorboard:\n",
" # kill processes\n",
" for pid in pids:\n",
" !kill {pid} 2> /dev/null\n",
" \n",
" # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n",
" !rm -rf {tensorboard.manager._get_info_dir()}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Experiment Management with Weights & Biases"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How do we manage experiments when we hit the limits of local TensorBoard?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is powerful and flexible and very scalable,\n",
"but running it requires engineering effort and babysitting --\n",
"you're running a database, writing data to it,\n",
"and layering a web application over it.\n",
"\n",
"This is a fairly common workflow for web developers,\n",
"but not so much for ML engineers.\n",
"\n",
"You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n",
"and it's as simple as running the command `tensorboard dev upload`\n",
"pointed at your logging directory.\n",
"\n",
"But there are strict limits to this free service:\n",
"1GB of tensor data and 1GB of binary data.\n",
"A single Text Recognizer model checkpoint is ~100MB,\n",
"and that's not particularly large for a useful model.\n",
"\n",
"Furthermore, all data is public,\n",
"so if you upload the inputs and outputs of your model,\n",
"anyone who finds the link can see them.\n",
"\n",
"Overall, tensorboard.dev works very well for certain academic and open projects\n",
"but not for industrial ML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To avoid these permissions and storage-limit issues,\n",
"you could use [git LFS](https://git-lfs.github.com/)\n",
"to track the binary data and tensor data,\n",
"which is more likely to be sensitive than metrics.\n",
"\n",
"The Hugging Face ecosystem uses TensorBoard and git LFS.\n",
"\n",
"It includes the Hugging Face Hub, a git server much like GitHub,\n",
"but designed first and foremost for collaboration on models and datasets,\n",
"rather than collaboration on code.\n",
"For example, the Hugging Face Hub\n",
"[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n",
"and officially has\n",
"[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n",
"avoiding the\n",
"[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n",
"that make using git LFS with GitHub expensive.\n",
"\n",
"However, we prefer to avoid mixing software version control and experiment management.\n",
"\n",
"First, using the Hub requires maintaining an additional git remote,\n",
"which is a hard ask for many engineering teams.\n",
"\n",
"Secondly, git-style versioning is an awkward fit for logging --\n",
"is it really sensible to create a new commit for each logging event while you're watching live?\n",
"\n",
"Instead, we prefer to use systems that solve experiment management with _databases_."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n",
"The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n",
"tool is [MLflow](https://github.com/mlflow/mlflow/)\n",
"and there are a number of\n",
"[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n",
"\n",
"These tools generally avoid any need to worry about hosting\n",
"(unless data governance rules require a self-hosted version).\n",
"\n",
"For a sampling of publicly-posted opinions on experiment management tools,\n",
"see these discussions from Reddit:\n",
"\n",
"- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n",
"- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n",
"\n",
"Among these tools, the FSDL recommendation is\n",
"[Weights & Biases](https://wandb.ai),\n",
"which we believe offers\n",
"- the best user experience, both in the Python SDKs and in the graphical interface\n",
"- the best integrations with other tools,\n",
"including\n",
"[Lightning](https://docs.wandb.ai/guides/integrations/lightning),\n",
"[Keras](https://docs.wandb.ai/guides/integrations/keras),\n",
"[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n",
"and even\n",
"[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n",
"and\n",
"- the best tools for collaboration.\n",
"\n",
"Below, we'll take care to point out which logging and management features\n",
"are available via generic interfaces in Lightning and which are W&B-specific."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"print(wandb.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding it to our experiment running code is extremely easy,\n",
"relative to the features we get, which is\n",
"one of the main selling points of W&B.\n",
"\n",
"We get most of our new experiment management features just by changing a single variable, `logger`, from\n",
"`TensorBoardLogger` to `WandbLogger`\n",
"and adding two lines of code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll see what each of these lines does for us below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this logger is built into and maintained by PyTorch Lightning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pytorch_lightning.loggers import WandbLogger\n",
"\n",
"\n",
"WandbLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to complete the rest of this notebook,\n",
"you'll need a Weights & Biases account.\n",
"\n",
"As with GitHub, the free tier for personal, academic, and open source work\n",
"is very generous.\n",
"\n",
"The Text Recognizer project will fit comfortably within the free tier.\n",
"\n",
"Run the cell below and follow the prompts to log in or create an account, or sign up\n",
"[here](https://wandb.ai/signup)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb login"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the cell below to launch an experiment tracked with Weights & Biases.\n",
"\n",
"The experiment can take between 3 and 10 minutes to run.\n",
"In that time, continue reading below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n",
" --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1\n",
" \n",
"last_expt = wandb.run\n",
"\n",
"wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see some new things in our output.\n",
"\n",
"For example, there's a note from `wandb` that the data is saved locally\n",
"and also synced to their servers.\n",
"\n",
"There's a link to a webpage for viewing the logged data and a name for our experiment --\n",
"something like `dandy-sunset-1`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The local logging and cloud syncing happen with minimal impact on performance,\n",
"because `wandb` launches a separate process to listen for events and upload them.\n",
"\n",
"That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Runs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To view results, head to the link in the notebook output\n",
"that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n",
"\n",
"There's no need to wait for training to finish.\n",
"\n",
"The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(last_expt.url)\n",
"IFrame(last_expt.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have landed on the run page\n",
"([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n",
"which collects up all of the information for a single experiment into a collection of tabs.\n",
"\n",
"We'll work through these tabs from top to bottom.\n",
"\n",
"Each header is also a link to the documentation for a tab."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n",
"This tab has an icon that looks like `(i)` or 🛈.\n",
"\n",
"The top section of this tab has high-level information about our run:\n",
"- Timing information, like start time and duration\n",
"- System hardware, hostname, and basic environment info\n",
"- Git repository link and state\n",
"\n",
"This information is collected and logged automatically.\n",
"\n",
"The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n",
"and summary metrics.\n",
"\n",
"Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n",
"\n",
"This tab has a line plot icon, something like 📈.\n",
"\n",
"It's also the default page you land on when looking at a W&B run.\n",
"\n",
"Charts are generated for everything we `.log` from PyTorch Lightning. The charts here are interactive and editable, and changes persist.\n",
"\n",
"Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n",
"\n",
"We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n",
"This tab has a computer chip icon.\n",
"\n",
"It contains\n",
"- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n",
"- CPU metrics: memory usage, utilization, thread counts\n",
"- Disk and network I/O levels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n",
"This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n",
"\n",
"The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n",
"This tab has an icon that looks like a stylized command prompt, `>_`.\n",
"\n",
"It contains information that was printed to stdout.\n",
"\n",
"This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n",
"\n",
"Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelSummary\" training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n",
"\n",
"For more on Lightning `Callback`s, see\n",
"[Lab 02a](https://fsdl.me/lab02a-colab)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n",
"This tab has a stylized document icon, something like 📄.\n",
"\n",
"You can use this tab to view any files saved with `wandb.save`.\n",
"\n",
"For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n",
"which we'll discuss shortly.\n",
"\n",
"But a few pieces of information automatically collected by W&B end up in this tab.\n",
"\n",
"Some highlights:\n",
" - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n",
" - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n",
"This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n",
"\n",
"This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n",
"\n",
"We store two kinds of binary files:\n",
" - `run_table`s of model inputs and outputs\n",
" - `model` checkpoints\n",
"\n",
"We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interactive Tables of Logged Media"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returning to the Charts tab,\n",
"notice that we have model inputs and outputs logged in structured tables\n",
"under the train, validation, and test sections.\n",
"\n",
"These tables are interactive as well\n",
"([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n",
"They support basic exploratory data analysis and are compatible with W&B's collaboration features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to charts in our run page, these tables also have their own pages inside the W&B web app."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n",
"table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n",
"\n",
"print(table_data_url)\n",
"IFrame(src=table_data_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting this to work requires more effort and more W&B-specific code\n",
"than the other features we've seen so far.\n",
"\n",
"We'll briefly explain the implementation here, for those who are interested.\n",
"\n",
"We use a custom Lightning `Callback`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n",
"\n",
"\n",
"ImageToTextTableLogger??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n",
"\n",
"The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n",
"\n",
"This behavior is sensible for metrics, which are low overhead, but not so much for media,\n",
"where we'd rather subsample and avoid holding on to too much information.\n",
"\n",
"So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n",
"\n",
"The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text_recognizer.lit_models.base import BaseImageToTextLitModel\n",
"\n",
"BaseImageToTextLitModel.add_on_logged_batches??"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Projects"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Everything we've seen so far has been related to a single run or experiment.\n",
"\n",
"Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n",
"\n",
"We organize our runs into \"projects\" and view them on the W&B \"project page\" \n",
"([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n",
"\n",
"By default in the Lightning integration, the project name is determined based on directory information.\n",
"This default can be overridden in the code when creating a `WandbLogger`,\n",
"but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see what the project page looks like for a longer-running project with lots of experiments.\n",
"\n",
"The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n",
"\n",
"print(project_url)\n",
"IFrame(src=project_url, width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n",
"\n",
"We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and rearrange the charts."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Artifacts"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond logging metrics and metadata from runs,\n",
"we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below pulls up all of the artifacts associated with the experiment we just ran."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Click on one of the `model` checkpoints -- the specific version doesn't matter.\n",
"\n",
"There are a number of tabs here.\n",
"\n",
"The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n",
"\n",
"The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n",
"which are added by default by the `WandbLogger`.\n",
"\n",
"The \"Files\" tab contains the actual file contents of the artifact.\n",
"\n",
"On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n",
"including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n",
"\n",
"You can click on these to explore the different versions and even directly compare them.\n",
"\n",
"If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Artifact storage is part of the W&B free tier.\n",
"\n",
"The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n",
"\n",
"That 100GB is sufficient to store ~700 model checkpoints for the Text Recognizer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can track your data storage and compare it to your limits at this URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n",
"\n",
"print(storage_tracker_url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Programmatic Access"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also programmatically access our data and metadata via the `wandb` API\n",
"([docs](https://docs.wandb.ai/guides/track/public-api-guide)):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wb_api = wandb.Api()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run = wb_api.run(\"/\".join( # fetch a run given\n",
" [last_expt.entity, # the user or org it was logged to\n",
" last_expt.project, # the \"project\", usually one of several per repo/application\n",
" last_expt.id] # and a unique ID\n",
"))\n",
"\n",
"hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n",
"\n",
"hist.head(5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hist.groupby(\"epoch\")[\"train/loss\"].mean()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the run also gives us access to the artifacts it logged:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# which artifacts were created and logged?\n",
"artifacts = run.logged_artifacts()\n",
"\n",
"for artifact in artifacts:\n",
" print(f\"artifact of type {artifact.type}: {artifact.name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Thanks to our `ImageToTextTableLogger`,\n",
"we can easily recreate training or validation data that came out of our `DataLoader`s,\n",
"which is normally ephemeral:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n",
"artifact_dir = Path(artifact.download(root=\"training/logs\"))\n",
"image_dir = artifact_dir / \"media\" / \"images\"\n",
"\n",
"images = [path for path in image_dir.iterdir()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"from IPython.display import Image\n",
"\n",
"Image(str(random.choice(images)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Advanced W&B API Usage: MLOps"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One of the strengths of a well-instrumented experiment tracking system is that it allows\n",
"automatic relation of information:\n",
"what were the inputs when this model's gradient spiked?\n",
"Which models have been trained on this dataset,\n",
"and what was their performance?\n",
"\n",
"Having access and automation around this information is necessary for \"MLOps\",\n",
"which applies contemporary DevOps principles to ML projects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below pull down the training data\n",
"for the model currently running the FSDL Text Recognizer app.\n",
"\n",
"This is just intended as a demonstration of what's possible,\n",
"so don't worry about understanding every piece of this,\n",
"and feel free to skip past it.\n",
"\n",
"MLOps is still a nascent field, and these tools and workflows are likely to change.\n",
"\n",
"For example, just before the course launched, W&B released a\n",
"[Model Registry layer](https://docs.wandb.ai/guides/models)\n",
"on top of artifact logging that aims to improve the developer experience for these workflows."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start from the same project we looked at in the project view:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n",
"\n",
"text_recognizer_project "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we search it for the text recognizer model currently being used in production:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# collect all versions of the text-recognizer ever put into production by...\n",
"\n",
"for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n",
" if art_type.name == \"prod-ready\": # for the prod-ready type\n",
" # and grabbing the text-recognizer\n",
" production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n",
"\n",
"# and then get the one that's currently being tested in CI by...\n",
"for text_recognizer in production_text_recognizers:\n",
" if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n",
" in_prod_text_recognizer = text_recognizer\n",
"\n",
"# view its metadata at the url or in the notebook\n",
"in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n",
"\n",
"print(in_prod_text_recognizer_url)\n",
"IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From its metadata, we can get information about how it was \"staged\" to be put into production,\n",
"and in particular which model checkpoint was used:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"staging_run = in_prod_text_recognizer.logged_by()\n",
"\n",
"training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n",
"training_ckpt.name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That checkpoint was logged by a training experiment, which is available as metadata.\n",
"\n",
"We can look at the training run for that model, either here in the notebook or at its URL:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"training_run = training_ckpt.logged_by()\n",
"print(training_run.url)\n",
"IFrame(src=training_run.url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And from there, we can access logs and metadata about training,\n",
"confident that we are working with the model that is actually in production.\n",
"\n",
"For example, we can pull down the data we logged and analyze it locally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"training_results = training_run.history(samples=10000)\n",
"training_results.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n",
"training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 10\n",
"training_results[\"validation/loss\"].dropna().iloc[idx]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The charts and webpages in Weights & Biases\n",
"are substantially more useful than ephemeral stdouts or raw logs on disk.\n",
"\n",
"If you're spun up on the project,\n",
"they accelerate debugging, exploration, and discovery.\n",
"\n",
"If not, they're not so much useful as they are overwhelming.\n",
"\n",
"We need to synthesize the raw logged data into information.\n",
"This helps us communicate our work with other stakeholders,\n",
"preserve knowledge and prevent repetition of work,\n",
"and surface insights faster.\n",
"\n",
"These workflows are supported by the W&B Reports feature\n",
"([docs here](https://docs.wandb.ai/guides/reports)),\n",
"whose reports mix W&B charts and tables with explanatory markdown text and embeds.\n",
"\n",
"Below are some common report patterns and\n",
"use cases and examples of each."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of the examples are from the FSDL Text Recognizer project.\n",
"You can find more of them\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n",
"where we've organized them into a report!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dashboard Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dashboards are a structured subset of the output from one or more experiments,\n",
"designed for quickly surfacing issues or insights,\n",
"like an accuracy or performance regression\n",
"or a change in the data distribution.\n",
"\n",
"Use cases:\n",
"- show the basic state of an ongoing experiment\n",
"- compare one experiment to another\n",
"- select the most important charts so you can spin back up into context on a project more quickly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n",
"\n",
"IFrame(src=dashboard_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pull Request Documentation Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In most software codebases,\n",
"pull requests are a key focal point\n",
"for units of work that combine\n",
"short-term communication and long-term information tracking.\n",
"\n",
"In ML codebases, it's more difficult to bring\n",
"sufficient information together to make PRs as useful.\n",
"At FSDL, we like to add documentary\n",
"reports with one or a small number of charts\n",
"that connect logged information in the experiment management system\n",
"to state in the version control software.\n",
"\n",
"Use cases:\n",
"- communication of results within a team, e.g. code review\n",
"- record-keeping that links pull request pages to raw logged info and makes it discoverable\n",
"- improving confidence in PR correctness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n",
"\n",
"IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Blog Post Report"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With sufficient effort, the logged data in the experiment management system\n",
"can be made clear enough to be consumed,\n",
"sufficiently contextualized to be useful outside the team, and\n",
"even beautiful.\n",
"\n",
"The result is a report that's closer to a blog post than a dashboard or internal document.\n",
"\n",
"Use cases:\n",
"- communication between teams or vertically in large organizations\n",
"- external technical communication for branding and recruiting\n",
"- attracting users or contributors\n",
"\n",
"Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n",
"[Boris Dayma](https://twitter.com/borisdayma)\n",
"and others:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n",
"\n",
"IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperparameter Optimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Many of our choices, like the depth of our network, the nonlinearities of our layers,\n",
"and the learning rate and other parameters of our optimizer, cannot be\n",
"([easily](https://arxiv.org/abs/1606.04474))\n",
"chosen by descent of the gradient of a loss function.\n",
"\n",
"But these parameters that impact the values of the parameters\n",
"we directly optimize with gradients, or _hyperparameters_,\n",
"can still be optimized,\n",
"essentially by trying options and selecting the values that worked best.\n",
"\n",
"In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n",
"\n",
"Expending more compute can squeeze out small amounts of additional validation or test performance\n",
"that makes for impressive results on leaderboards but typically doesn't translate\n",
"into better user experience.\n",
"\n",
"In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n",
"built into your other tooling.\n",
"\n",
"Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n",
"([docs](https://docs.wandb.ai/guides/sweeps)).\n",
"\n",
"It also supports a number of more advanced tools, like\n",
"[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n",
"for early termination of poorly-performing runs.\n",
"\n",
"We can use the same training script and we don't need to run an optimization server.\n",
"\n",
"We just need to write a configuration yaml file\n",
"([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n",
"like the one below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile training/simple-overfit-sweep.yaml\n",
"# first we specify what we're sweeping\n",
"# we specify a program to run\n",
"program: training/run_experiment.py\n",
"# we optionally specify how to run it, including setting default arguments\n",
"command: \n",
" - ${env}\n",
" - ${interpreter}\n",
" - ${program}\n",
" - \"--wandb\"\n",
" - \"--overfit_batches\"\n",
" - \"1\"\n",
" - \"--log_every_n_steps\"\n",
" - \"25\"\n",
" - \"--max_epochs\"\n",
" - \"100\"\n",
" - \"--limit_test_batches\"\n",
" - \"0\"\n",
" - ${args} # these arguments come from the sweep parameters below\n",
"\n",
"# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n",
"method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n",
"metric:\n",
" name: train/loss\n",
" goal: minimize\n",
"parameters: \n",
" # LineCNN hyperparameters\n",
" window_width:\n",
" values: [8, 16, 32, 64]\n",
" window_stride:\n",
" values: [4, 8, 16, 32]\n",
" # Transformer hyperparameters\n",
" tf_layers:\n",
" values: [1, 2, 4, 8]\n",
" # we can also fix some values, just like we set default arguments\n",
" gpus:\n",
" value: 1\n",
" model_class:\n",
" value: LineCNNTransformer\n",
" data_class:\n",
" value: IAMLines\n",
" loss:\n",
" value: transformer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the config we launch a \"controller\":\n",
"a lightweight process that just decides what hyperparameters to try next\n",
"and coordinates the heavier-weight training.\n",
"\n",
"This lives on the W&B servers, so there are no headaches about opening ports for communication,\n",
"cleaning up when it's done, etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n",
"simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and then we can launch an \"agent\" to follow the orders of the controller:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# interrupt twice to terminate this cell if it's running too long,\n",
"# it can be over 15 minutes with some hyperparameters\n",
"\n",
"!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n",
"\n",
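"If `--count` is not provided, the agent keeps launching runs until the sweep finishes or is terminated (which can be done from the W&B interface).\n",
"Random and Bayesian sweeps never finish on their own, so for those the agent runs indefinitely unless you stop it."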
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The agents make for a slick workflow for distributing sweeps across GPUs.\n",
"\n",
"We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n",
"which controls which GPUs are accessible by a process, to launch\n",
"parallel agents on separate GPUs on the same machine."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n",
"# open another terminal\n",
"CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n",
"# and so on\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We include optional exercises with the labs for learners who want to dive deeper on specific topics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟 Contribute to a hyperparameter search."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n",
"\n",
"There are ~10,000,000 potential hyperparameter combinations,\n",
"and each takes 30 minutes to test,\n",
"so checking each possibility will take over 500 years of compute time.\n",
"Best get cracking then!\n",
"\n",
"Run the cell below to pull up a dashboard and print the URL where you can check on the current status."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_entity = \"fullstackdeeplearning\"\n",
"sweep_project = \"fsdl-line-recognizer-2022\"\n",
"sweep_id = \"e0eo43eu\"\n",
"sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n",
"\n",
"print(sweep_url)\n",
"IFrame(src=sweep_url, width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also retrieve information about the sweep from the API,\n",
"including the hyperparameters being swept over."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hyperparams = sweep_info.config[\"parameters\"]\n",
"hyperparams"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to contribute to this sweep,\n",
"run the cell below after changing the count to a number greater than 0.\n",
"\n",
"Each iteration runs for 30 minutes if it does not crash,\n",
"e.g. due to out-of-memory errors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"count = 0 # off by default, increase it to join in!\n",
"\n",
"if count:\n",
" !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Write some manual logging in `wandb`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the FSDL Text Recognizer codebase,\n",
"we almost exclusively log to W&B through Lightning,\n",
"rather than through the `wandb` Python SDK.\n",
"\n",
"If you're interested in learning how to use W&B directly, e.g. with another training framework,\n",
"try out this quick exercise that introduces the key players in the SDK."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n",
"\n",
"It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n",
"\n",
"Add W&B metric and artifact logging to this cell:\n",
"- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n",
"- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import os\n",
"import random\n",
"\n",
"import wandb\n",
"\n",
"\n",
"os.makedirs(\"logs\", exist_ok=True)\n",
"\n",
"project = \"trying-wandb\"\n",
"config = {\"steps\": 50}\n",
"\n",
"\n",
"with wandb.init(project=project, config=config) as run:\n",
" steps = wandb.config[\"steps\"]\n",
" \n",
" for ii in range(steps):\n",
" loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n",
" \n",
" with open(\"logs/hello.txt\", \"w\") as f:\n",
" f.write(\"hello from wandb, my dudes!\")\n",
" \n",
" run_id = run.id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hello_run = wb_api.run(f\"{project}/{run_id}\")\n",
"\n",
"# check for logged loss data\n",
"if \"loss\" not in hello_run.history().keys():\n",
" print(\"loss not logged 🥲\")\n",
"else:\n",
" print(\"loss logged successfully 🥞\")\n",
" if len(hello_run.history()[\"loss\"]) != steps:\n",
" print(\"loss not logged on all steps 🥲\")\n",
" else:\n",
" print(\"loss logged on all steps 🥞\")\n",
"\n",
"artifacts = hello_run.logged_artifacts()\n",
"\n",
"# check for artifact with the right name\n",
"if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n",
" print(\"hello artifact not logged 🥲\")\n",
"else:\n",
" print(\"hello artifact logged successfully 🥞\")\n",
" # check for the file inside the artifacts\n",
" if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n",
" print(\"could not find hello.txt 🥲\")\n",
" else:\n",
" print(\"hello.txt logged successfully 🥞\")\n",
" \n",
" \n",
"hello_run"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you observe interesting phenomena during training,\n",
"from promising hyperparameter combos to software bugs to strange model behavior,\n",
"turn the charts into a W&B report and share it with the FSDL community or\n",
"[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n",
"with a link to them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# check the sweep_info.config above to see the model and data hyperparameters\n",
"# read through the --help output for all potential arguments\n",
"%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n",
" --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n",
" --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n",
" --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n",
" --help # remove this line to run an experiment instead of printing help\n",
" \n",
"last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🌟🌟🌟 Add logging of tensor statistics."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In addition to logging model inputs and outputs as human-interpretable media,\n",
"it's also frequently useful to see information about their numerical values."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're interested in learning more about metric calculation and logging with Lightning,\n",
"use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n",
"to add tensor statistic logging to the `LineCNNTransformer`.\n",
"\n",
"`torchmetrics` comes with built-in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n",
"\n",
"All three are useful, but start by adding just one."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n",
"- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n",
" - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n",
"- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n",
" - Base your code on the calculation and logging of the `val_cer` metric.\n",
" - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ",
"collapsed_sections": [],
"name": "lab04_experiments.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab08/notebooks/lab05_troubleshooting.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 05: Troubleshooting & Testing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`\n",
"- How to implement tests for ML training systems in particular\n",
"- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 5\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # allow \"hot-reloading\" of modules\n",
" %load_ext autoreload\n",
" %autoreload 2\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
" \n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sThWeTtV6fL_"
},
"outputs": [],
"source": [
"from IPython.display import display, HTML, IFrame\n",
"\n",
"full_width = True\n",
"frame_height = 720 # adjust for your screen\n",
"\n",
"if full_width: # if we want the notebook to take up the whole width\n",
" # add styling to the notebook's HTML directly\n",
" display(HTML(\"\"))\n",
" display(HTML(\"\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xFP8lU4nSg1P"
},
"source": [
"# Linting Python and Shell Scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cXbdYfFlPhZ-"
},
"source": [
"### Automatically linting with `pre-commit`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ysqqb2GjvLrz"
},
"source": [
"We want to keep our code clean and uniform across developers\n",
"and time.\n",
"\n",
"Applying the cleanliness checks and style rules should be\n",
"as painless and automatic as possible.\n",
"\n",
"For this purpose, we recommend bundling linting tools together\n",
"and enforcing them on all commits with\n",
"[`pre-commit`](https://pre-commit.com/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XvqtZChKvLr0"
},
"source": [
"In addition to running on every commit,\n",
"`pre-commit` separates the model development environment from the environments\n",
"needed for the linting tools, preventing conflicts\n",
"and simplifying maintenance and onboarding."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y0XuIuKOXhJl"
},
"source": [
"This cell runs `pre-commit`.\n",
"\n",
"The first time it is run on a machine, it will install the environments for all tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hltYGbpNvLr1"
},
"outputs": [],
"source": [
"!pre-commit run --all-files"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gLw08gIkvLr1"
},
"source": [
"The output lists all the checks that are run and whether they are passed.\n",
"\n",
"Notice there are a number of simple version-control hygiene practices included\n",
"that aren't even specific to Python, much less to machine learning.\n",
"\n",
"For example, several of the checks prevent accidental commits with private keys, large files, \n",
"leftover debugger statements, or merge conflict annotations in them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RHEEjb9kvLr1"
},
"source": [
"These linting actions are configured via\n",
"([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n",
"a YAML file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dgXa8BzrvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8HYc_WbTvLr2"
},
"source": [
"Most of the general cleanliness checks are from hooks built by `pre-commit`.\n",
"\n",
"See the comments and links in the `.pre-commit-config.yaml` for more:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K9rTgRqzvLr2"
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep repos -A 15"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ptkO7aPvLr2"
},
"source": [
"Let's take a look at the section of the file\n",
"that applies most of our Python style enforcement with\n",
"[`flake8`](https://flake8.pycqa.org/en/latest/):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ALsRKfcevLr3",
"scrolled": true
},
"outputs": [],
"source": [
"!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_Q0BwQUXbg6"
},
"source": [
"The majority of the style checking behavior we want comes from the\n",
"`additional_dependencies`, which are\n",
"[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n",
"that extend `flake8`'s list of lints.\n",
"\n",
"Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n",
"\n",
"We keep the configuration information for `flake8`\n",
"separate from that for `pre-commit`\n",
"in case we want to use additional tools with `flake8`,\n",
"e.g. if some developers want to integrate it directly into their editor,\n",
"and so that if we change away from `pre-commit`\n",
"but keep `flake8` we don't have to\n",
"recreate our configuration in a different tool.\n",
"\n",
"As much as possible, codebases should strive for single sources of truth\n",
"and link back to those sources of truth with documentation or comments,\n",
"as in the last line above.\n",
"\n",
"Let's take a look at the contents of `.flake8`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "doC_4WQwvLr3"
},
"outputs": [],
"source": [
"!cat .flake8"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Nq6HnyU0M47"
},
"source": [
"There's a lot here! We'll focus on the most important bits."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U4PiB8CPvLr3"
},
"source": [
"Linting tools in Python generally work by emitting error codes\n",
"with one or more letters followed by three numbers.\n",
"The `select` argument picks which error codes we want to check for.\n",
"Error codes are matched by prefix,\n",
"so for example `B` matches `BTS101` and\n",
"`G1` matches `G102` and `G199` but not `ARG404`.\n",
"\n",
"Certain codes are `ignore`d in the default `flake8` style,\n",
"which is done via the `ignore` argument,\n",
"and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n",
"For example, we rely on `black` to do our formatting,\n",
"so we ignore some of `flake8`'s formatting codes.\n",
"\n",
"Together, these settings define our project's particular style.\n",
"\n",
"But not every file fits this style perfectly.\n",
"Most of the conventions in `black` and `flake8` come from the style-defining\n",
"[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n",
"which exhorts you to \"know when to be inconsistent\".\n",
"\n",
"To allow ourselves to be inconsistent when we know we should be,\n",
"`flake8` includes `per-file-ignores`,\n",
"which let us ignore specific warnings in specific files.\n",
"This is one of the \"escape valves\"\n",
"that makes style enforcement tolerable.\n",
"We can also `exclude` files in the `pre-commit` config itself.\n",
"\n",
"For details on selecting and ignoring,\n",
"see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html).\n",
"\n",
"For definitions of the error codes from `flake8` itself,\n",
"see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n",
"Individual extensions list their added error codes in their documentation,\n",
"e.g. `darglint` does so\n",
"[here](https://github.com/terrencepreilly/darglint#error-codes)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NL0TpyPsvLr4"
},
"source": [
"The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n",
"\n",
"You can read more about each in their documentation:\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mFsZC0a7vLr4"
},
"source": [
"### Linting via a script and using `shellcheck`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RYjpuFwjXkJc"
},
"source": [
"To avoid needing to think about `pre-commit`\n",
"(was the command `pre-commit run` or `pre-commit check`?)\n",
"while developing locally,\n",
"we might put our linters into a shell script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mXlLFWmavLr4"
},
"outputs": [],
"source": [
"!cat tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PPxHpRIB3nbw"
},
"source": [
"These kinds of short and simple shell scripts are common in projects\n",
"of intermediate size.\n",
"\n",
"They are useful for adding automation and reducing friction."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TMuPBpAi2qwl"
},
"source": [
"But these scripts are code,\n",
"and all code is susceptible to bugs and subject to concerns of style consistency."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SQRg3ZqXvLr4"
},
"source": [
"We can't check these scripts with tools that lint Python code,\n",
"so we include a shell script linting tool,\n",
"[`shellcheck`](https://www.shellcheck.net/),\n",
"in our `pre-commit`.\n",
"\n",
"More so than checking for correct style,\n",
"this tool checks for common bugs or surprising behaviors of shells,\n",
"which are unfortunately numerous."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zkfhE1srvLr4"
},
"outputs": [],
"source": [
"script_filename = \"tasks/lint.sh\"\n",
"!pre-commit run shellcheck --files {script_filename}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXU9TRrwvLr4"
},
"source": [
"That script has already been tested, so we don't see any errors.\n",
"\n",
"Try copying over a script you've written yourself or\n",
"even from a popular repo that you like\n",
"(by adding to the notebook directory or by making a cell\n",
"with `%%writefile` at the top)\n",
"and test it by changing the `script_filename`.\n",
"\n",
"You'd be surprised at the classes of subtle bugs possible in bash!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "81MhAL-TvLr5"
},
"source": [
"### Try \"unofficial bash strict mode\" for louder failures in scripts"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hSwhs_zUvLr5"
},
"source": [
"Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n",
"[@redsymbol](https://twitter.com/redsymbol),\n",
"which appear at the top of the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o-j0vSxEvLr5"
},
"outputs": [],
"source": [
"!head -n 3 tasks/lint.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d2iJU5jlvLr5"
},
"source": [
"The core idea of strict mode is to fail more loudly.\n",
"This is a desirable behavior of scripts,\n",
"like the ones we're writing,\n",
"even though it's an undesirable behavior for an interactive shell --\n",
"it would be unpleasant to be logged out every time you hit an error.\n",
"\n",
"`set -u` means scripts fail if a variable's value is `u`nset,\n",
"i.e. not defined.\n",
"Otherwise bash is perfectly happy to allow you to reference undefined variables.\n",
"The result is just an empty string, which can lead to maddeningly weird behavior.\n",
"\n",
"`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n",
"rather than using the exit code of the last command.\n",
"Unix tools are perfectly happy to work on nonsense input,\n",
"like sorting error messages, instead of the filenames you meant to send.\n",
"\n",
"You can read more about these choices\n",
"[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n",
"along with considerations for working with other non-conforming scripts in \"strict mode\"\n",
"and for handling resource teardown when scripts error out."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s1XqsrU_XWWS"
},
"source": [
"# Testing ML Codebases"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPNzeq3NYF2W"
},
"source": [
"## Testing Python code with `pytest`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zq5e_x6gc9Vu"
},
"source": [
"\n",
"ML codebases are Python first and foremost, so first let's get some Python tests going."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DC3GxYz6_R9"
},
"source": [
"At a basic level,\n",
"we can write functions that `assert`\n",
"that our code behaves as expected in\n",
"a given scenario and include them in the same module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rvd-GNwv63W1"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.metrics import test_character_error_rate\n",
"\n",
"test_character_error_rate??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iVB2TsQS5BTq"
},
"source": [
"The standard tool for testing Python code is\n",
"[`pytest`](https://docs.pytest.org/en/7.1.x/).\n",
"\n",
"We can use it as a command-line tool in a variety of ways,\n",
"including to execute these kinds of tests.\n",
"\n",
"If passed a filename, `pytest` will look for\n",
"any classes that start with `Test` or\n",
"any functions that start with `test_` and run them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u8sQguyJvLr6",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest text_recognizer/lit_models/metrics.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "92tkBCllvLr6"
},
"source": [
"After the results of the tests (pass or fail) are returned,\n",
"you'll see a report of \"coverage\" from\n",
"[`codecov`](https://about.codecov.io/).\n",
"\n",
"This coverage report tells us which files and how many lines in those files\n",
"were touched by the testing suite."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PllSUe0s5xvU"
},
"source": [
"We do not actually need to provide the names of files with tests in them to `pytest`\n",
"in order for it to run our tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4qOBHJnTZM9x"
},
"source": [
"By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n",
"\n",
"It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n",
"to separate these from the rest of your code\n",
"in a folder or folders named `tests`,\n",
"rather than scattering them around the repo."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "acjsYTNSvLr6"
},
"outputs": [],
"source": [
"!ls text_recognizer/tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WZQQZUF0vLr6"
},
"source": [
"Let's take a look at a specific example:\n",
"the tests for some of our utilities around\n",
"custom PyTorch Lightning `Callback`s."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oS0xKv1evLr6"
},
"outputs": [],
"source": [
"from text_recognizer.tests import test_callback_utils\n",
"\n",
"\n",
"test_callback_utils.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lko8msn-vLr7"
},
"source": [
"Notice that we can easily import this as a module!\n",
"\n",
"That's another benefit of organizing tests into specialized files."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5A85FUNv75Fr"
},
"source": [
"The particular utility we're testing\n",
"here is designed to prevent crashes:\n",
"it checks for a particular type of error and turns it into a warning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jl4-DiVe76sw"
},
"outputs": [],
"source": [
"from text_recognizer.callbacks.util import check_and_warn\n",
"\n",
"check_and_warn??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B6E0MhduvLr7"
},
"source": [
"Error-handling code is a common cause of bugs,\n",
"a fact discovered\n",
"[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n",
"so it's very important to test it well!\n",
"\n",
"We start with a very basic test,\n",
"which does not touch anything\n",
"outside of the Python standard library,\n",
"even though this tool is intended to be used\n",
"with more complex features of third-party libraries,\n",
"like `wandb` and `tensorboard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xx5koQmJvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_simple??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MZe9-JVjvLr7"
},
"source": [
"Here, we are just testing the core logic.\n",
"This test won't catch many bugs,\n",
"but when it does fail, something has gone seriously wrong.\n",
"\n",
"These kinds of tests are important for resolving a bug:\n",
"we learn nearly as much from the tests that passed\n",
"as we did from the tests that failed.\n",
"If this test has failed, possibly along with others,\n",
"we can rule out an issue in one of the large external codebases\n",
"touched in the other tests, saving us lots of time in our troubleshooting.\n",
"\n",
"The reasoning for the test is explained in the docstrings, \n",
"which are close to the code.\n",
"\n",
"Your test suite should be as welcoming\n",
"as the rest of your codebase!\n",
"The people reading it, for example yourself in six months, \n",
"are likely upset and in need of some kindness.\n",
"\n",
"More practically, we want to keep our time to resolve errors as short as possible,\n",
"and five minutes to write a good docstring now\n",
"can save five minutes during an outage, when minutes really matter."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Om9k-uXhvLr7"
},
"source": [
"That basic test is a start, but it's not enough by itself.\n",
"There's a specific error case that triggered the addition of this code.\n",
"\n",
"So we test that it's handled as expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fjbsb5FvvLr7"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_tblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CGAIZTUjvLr7"
},
"source": [
"That test can fail if the libraries change around our code,\n",
"i.e. if the `TensorBoardLogger` gets a `log_table` method.\n",
"\n",
"We want to be careful when making assumptions\n",
"about other people's software,\n",
"especially for fast-moving libraries like Lightning.\n",
"If we test that those assumptions hold willy-nilly,\n",
"we'll end up with tests that fail because of\n",
"harmless changes in our dependencies.\n",
"\n",
"Tests that require a ton of maintenance and updating\n",
"without leading to code improvements soak up\n",
"more engineering time than they save\n",
"and cause distrust in the testing suite.\n",
"\n",
"We include this test because `TensorBoardLogger` getting\n",
"a `log_table` method will _also_ change the behavior of our code\n",
"in a breaking way, and we want to catch that before it breaks\n",
"a model training job."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jsy95KAvvLr7"
},
"source": [
"Adding error handling can also accidentally kill the \"happy path\"\n",
"by raising an error incorrectly.\n",
"\n",
"So we explicitly test the _absence of an error_,\n",
"not just its presence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LRlIOkjmvLr8"
},
"outputs": [],
"source": [
"test_callback_utils.test_check_and_warn_wandblogger??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osiqpLynvLr8"
},
"source": [
"There are more tests we could build, e.g. manipulating classes and testing the behavior,\n",
"testing more classes that might be targeted by `check_and_warn`, or\n",
"asserting that warnings are raised to the command line.\n",
"\n",
"But these three basic tests are likely to catch most changes that would break our code here,\n",
"and they're a lot easier to write than the others.\n",
"\n",
"If this utility starts to get more usage and become a critical path for lots of features, we can always add more!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dm285JE5vLr8"
},
"source": [
"## Interleaving testing and documentation with `doctests`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHWQvgA8vLr8"
},
"source": [
"One function of tests is to build user/reader confidence in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wrhiJBXFvLr8"
},
"source": [
"One function of documentation is to build user/reader knowledge in code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1vu12LDhvLr8"
},
"source": [
"These functions are related. Let's put them together:\n",
"put code in a docstring and test that code.\n",
"\n",
"This feature is part of the\n",
"Python standard library via the\n",
"[`doctest` module](https://docs.python.org/3/library/doctest.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rmfIOwXd-Qt7"
},
"source": [
"Here's an example from our `torch` utilities.\n",
"\n",
"The `first_appearance` function can be used to\n",
"e.g. quickly look for stop tokens,\n",
"giving the length of each sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZzURGcD9vLr8"
},
"outputs": [],
"source": [
"from text_recognizer.lit_models.util import first_appearance\n",
"\n",
"\n",
"first_appearance??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0VtYcJ1WvLr8"
},
"source": [
"Notice that in the \"Examples\" section,\n",
"there's a short block of code formatted as a\n",
"Python interpreter session,\n",
"complete with outputs.\n",
"\n",
"We can copy and paste that code and\n",
"check that we get the right outputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Dj4lNOxJvLr9"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y9AWHFoIvLr9"
},
"source": [
"We can run the test with `pytest` by passing a command line argument,\n",
"`--doctest-modules`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JMaAxv5ovLr9"
},
"outputs": [],
"source": [
"!pytest --doctest-modules text_recognizer/lit_models/util.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6-2_aOUfvLr9"
},
"source": [
"With the\n",
"[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n",
"running `doctest`s happens automatically\n",
"when `pytest` is invoked."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "my_keokPvLr9"
},
"source": [
"## Basic tests for data code"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qj3Bq_j2_A8o"
},
"source": [
"ML code can be hard to test\n",
"since it involves very heavy artifacts, like models and data,\n",
"and very expensive jobs, like training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DT5OmgrQvLr9"
},
"source": [
"For testing our data-handling code in the FSDL codebase,\n",
"we mostly just use `assert`s,\n",
"which throw errors when behavior differs from expectation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Bdzn5g4TvLr9"
},
"outputs": [],
"source": [
"!grep \"assert\" -r text_recognizer/data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2aTlfu4_vLr-"
},
"source": [
"This isn't great practice,\n",
"especially as a codebase grows,\n",
"because we can't easily know when these are executed\n",
"or incorporate them into\n",
"testing automation and coverage analysis tools."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IaMTdmbZ_mkW"
},
"source": [
"So it's preferable to collect up these assertions of simple data properties\n",
"into tests that are run like our other tests.\n",
"\n",
"The test below checks whether any data is leaking\n",
"between training, validation, and testing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qx7cxiDdvLr-"
},
"outputs": [],
"source": [
"from text_recognizer.tests.test_iam import test_iam_data_splits\n",
"\n",
"\n",
"test_iam_data_splits??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16TJwhd1vLr-"
},
"source": [
"Notice that we were able to load the test into the notebook\n",
"because it is in a module,\n",
"and so we can run it here as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mArITFkYvLr-"
},
"outputs": [],
"source": [
"test_iam_data_splits()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E4F2uaclvLr-"
},
"source": [
"But we're checking something pretty simple here,\n",
"so the new code in each test is just a single line.\n",
"\n",
"What if we wanted to test more complex properties,\n",
"like comparing rows or calculating statistics?\n",
"\n",
"We'll end up writing more complex code that might itself have subtle bugs,\n",
"requiring tests for our tests and suffering from\n",
"\"tester's regress\".\n",
"\n",
"This is the phenomenon,\n",
"named by analogy with\n",
"[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n",
"in sociology of science,\n",
"where the validity of our tests is itself\n",
"up for dispute only resolvable by testing the tests,\n",
"but those tests are themselves possibly invalid."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nUGT06gdvLr-"
},
"source": [
"We cut this Gordian knot by using\n",
"a library or framework that is well-tested.\n",
"\n",
"We recommend checking out\n",
"[`great_expectations`](https://docs.greatexpectations.io/docs/)\n",
"if you're looking for a high-quality data testing tool."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dQ5vNsq3vLr-"
},
"source": [
"Especially with data, some tests are particularly \"heavy\" --\n",
"they take a long time,\n",
"and we might want to run them\n",
"on different machines\n",
"and on a different schedule\n",
"than our other tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xephcb0LvLr-"
},
"source": [
"For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n",
"\n",
"We can't just use a cached version of the data,\n",
"since that won't actually execute the code!\n",
"\n",
"This test will take\n",
"as long to run\n",
"and consume as many resources as\n",
"a full download of the data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YSN4w2EqvLr-"
},
"source": [
"`pytest` allows the separation of tests\n",
"into suites with `mark`s,\n",
"which \"tag\" tests with names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V0rScrcXvLr_",
"scrolled": false
},
"outputs": [],
"source": [
"!pytest --markers | head -n 10"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lr5Ca7B0vLr_"
},
"source": [
"We can choose to run tests with a given mark\n",
"or to skip tests with a given mark,\n",
"among other basic logical operations around combining and filtering marks,\n",
"with `-m`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xmw-Eb1ZvLr_"
},
"outputs": [],
"source": [
"!wandb login # one test requires wandb authentication\n",
"\n",
"!pytest -m \"not data and not slow\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5LuERxOXX_UJ"
},
"source": [
"## Testing training with memorization tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AnWLN4lRvLsA"
},
"source": [
"Training is the process by which we convert inert data into executable models,\n",
"so it is dependent on both.\n",
"\n",
"We decouple checking whether the script has a critical bug\n",
"from whether the data or model code is broken\n",
"by testing on some basic \"fake data\",\n",
"based on a utility from `torchvision`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k4NIc3uWvLsA"
},
"outputs": [],
"source": [
"from text_recognizer.data import FakeImageData\n",
"\n",
"\n",
"FakeImageData.__doc__"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "deN0swwlvLsA"
},
"source": [
"We then test on the actual data with a smaller version of the real model.\n",
"\n",
"We use the Lightning `--fast_dev_run` feature,\n",
"which sets the number of training, validation, and test batches to `1`.\n",
"\n",
"We use a smaller version so that this test can run in just a few minutes\n",
"on a CPU without acceleration.\n",
"\n",
"That allows us to run our tests in environments without GPUs,\n",
"which saves on costs for executing tests.\n",
"\n",
"Here's the script:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z4J0_uD9vLsA"
},
"outputs": [],
"source": [
"!cat training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-7u9zS1vLsA",
"scrolled": false
},
"outputs": [],
"source": [
"! ./training/tests/test_run_experiment.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UTzfo11KClV3"
},
"source": [
"The above tests don't actually check\n",
"whether any learning occurs,\n",
"they just check\n",
"whether training runs mechanically,\n",
"without any errors.\n",
"\n",
"We also need a\n",
"[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n",
"for learning.\n",
"For that we recommend checking whether\n",
"the model can learn the right\n",
"outputs for a single batch --\n",
"to \"memorize\" the outputs for\n",
"a particular input.\n",
"\n",
"This memorization test won't\n",
"catch every bug or issue in training,\n",
"which is notoriously difficult,\n",
"but it will flag\n",
"some of the most serious issues."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0DVSp3aAvLsA"
},
"source": [
"The script below runs a memorization test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2DFVVrxpvLsA"
},
"source": [
"It takes up to two arguments:\n",
"a `MAX`imum number of `EPOCHS` to run for and\n",
"a `CRITERION` value of the loss to test against.\n",
"\n",
"The test passes if the loss is lower than the `CRITERION` value\n",
"after the `MAX`imum number of `EPOCHS` has passed."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oEhJH0e5vLsB"
},
"source": [
"The important line in this script is the one that invokes our training script,\n",
"`training/run_experiment.py`.\n",
"\n",
"The arguments to `run_experiment` have been tuned for maximum possible speed:\n",
"turning off regularization, shrinking the model,\n",
"and skipping parts of Lightning that we don't want to test."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T-fFs1xEvLsB"
},
"outputs": [],
"source": [
"!cat training/tests/test_memorize_iam.sh"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X-47tUA_YNGe"
},
"source": [
"If you'd like to see what a memorization run looks like,\n",
"flip the `running_memorization` flag to `True`\n",
"and watch the results stream in to W&B.\n",
"\n",
"The cell should run in about ten minutes on a commodity GPU."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GwTEsZwKvLsB"
},
"outputs": [],
"source": [
"%%time\n",
"running_memorization = False\n",
"\n",
"if running_memorization:\n",
" max_epochs = 1000\n",
" loss_criterion = 0.05\n",
" !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPoFCoEcC8SV"
},
"source": [
"# Troubleshooting model speed with the PyTorch Profiler"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpbN-Om2Drf-"
},
"source": [
"Testing code is only half the story here:\n",
"we also need to fix the issues that our tests flag.\n",
"This is the process of troubleshooting.\n",
"\n",
"In this lab,\n",
"we'll focus on troubleshooting model performance issues:\n",
"what to do when your model runs too slowly."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NZzwELPXvLsD"
},
"source": [
"Troubleshooting deep neural networks for speed is challenging.\n",
"\n",
"There are at least three different common approaches,\n",
"each with an increasing level of skill required:\n",
"\n",
"1. Follow best practices advice from others\n",
"([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n",
"[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n",
"2. Take code that runs slowly and use empirical observations to iteratively improve it.\n",
"3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n",
"\n",
"For the full stack deep learning engineer,\n",
"the final level is typically out of reach,\n",
"unless you're specializing in the model performance\n",
"part of the stack in particular.\n",
"\n",
"So we recommend reaching the middle level,\n",
"and this segment of the lab walks through the\n",
"tools that make this easier."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3_yp87UrFZ8M"
},
"source": [
"Because neural network training involves GPU acceleration,\n",
"generic Python profiling tools like\n",
"[`py-spy`](https://github.com/benfred/py-spy)\n",
"won't work, and\n",
"we'll need tools specialized for tracing and profiling DNN training."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yspsYVFGEyZm"
},
"source": [
"In general, these tools are for observing what happens while your code is executing:\n",
"_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n",
"\n",
"Because they help us observe the execution in detail,\n",
"they will also help us understand just what is going on during\n",
"a PyTorch training step in greater detail."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YqXq2hKuvLsE"
},
"source": [
"To support profiling and tracing,\n",
"we've added a new argument to `training/run_experiment.py`, `--profile`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z_GMMViWvLsE"
},
"outputs": [],
"source": [
"!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZldoksHPvLsE"
},
"source": [
"As with experiment management, this relies mostly on features of PyTorch Lightning,\n",
"which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n",
"and we just add a few lines of customization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F2iJ0_A6vLsE"
},
"outputs": [],
"source": [
"!cat training/run_experiment.py | grep args.profile -A 5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Aw3ppgndvLsE"
},
"source": [
"For more on profiling with Lightning, see the\n",
"[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uCAmNW3QEtcD"
},
"source": [
"The cell below runs an epoch of training with tracing and profiling turned on\n",
"and then saves the results locally and to W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "t4o3ylDgr46F",
"scrolled": false
},
"outputs": [],
"source": [
"import glob\n",
"\n",
"import torch\n",
"import wandb\n",
"\n",
"from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n",
"\n",
"\n",
"# make it easier to separate these from training runs\n",
"%env WANDB_JOB_TYPE=profile\n",
"\n",
"batch_size = 16\n",
"num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n",
"gpus = 1 # must be run with accelerator\n",
"\n",
"%run training/run_experiment.py --wandb --profile \\\n",
" --max_epochs=1 \\\n",
" --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n",
" --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n",
" --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus=1\n",
"\n",
"latest_expt = wandb.run\n",
"\n",
"try: # add execution trace to logged and versioned binaries\n",
" folder = wandb.run.dir\n",
" trace_matcher = wandb.run.dir + \"/*.pt.trace.json\"\n",
" trace_file = glob.glob(trace_matcher)[0]\n",
" trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n",
" trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n",
" wandb.log_artifact(trace_at)\n",
"except IndexError:\n",
" print(\"trace not found\")\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ePTkS3EqO5tN"
},
"source": [
"We get out a table of statistics in the terminal,\n",
"courtesy of Lightning.\n",
"\n",
"Each row lists an operation\n",
"and provides information,\n",
"described in the column headers,\n",
"about the time spent on that operation\n",
"across all the training steps we profiled.\n",
"\n",
"With practice, some useful information can be read out from this table,\n",
"but it's better to start from both a less detailed view,\n",
"in the TensorBoard dashboard,\n",
"and a more detailed view,\n",
"using the Chrome Trace viewer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TzV62f3c7-Bi"
},
"source": [
"## High-level statistics from the PyTorch Profiler in TensorBoard"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mNPKXkYw8NWd"
},
"source": [
"Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CbItwuT88eAV"
},
"outputs": [],
"source": [
"your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n",
"\n",
"print(your_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jE_LooMYHFpF"
},
"source": [
"If at any point you run into issues,\n",
"like the description not matching what you observe,\n",
"check out one of our example runs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "za2zybSwIo5C"
},
"outputs": [],
"source": [
"example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n",
"print(example_tensorboard_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xlrhl1n4HYU6"
},
"source": [
"Once the TensorBoard session has loaded up,\n",
"we are dropped into the Overview\n",
"(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n",
"for an example).\n",
"\n",
"In the top center, we see the **GPU Summary** for our system.\n",
"\n",
"In addition to the name of our GPU,\n",
"there are a few configuration details and top-level statistics.\n",
"They are (tersely) documented\n",
"[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MmBhUDgDLhd1"
},
"source": [
"- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n",
"this is effectively a coarse \"version number\" for your GPU hardware.\n",
"It indexes which features are available,\n",
"with more advanced features being available only at higher compute capabilities.\n",
"It does not directly index the speed or memory of the GPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "voUgT6zuLyi0"
},
"source": [
"- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yl-IndtXE4b4"
},
"source": [
"- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n",
"for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n",
"Tensor Cores.\n",
"If you're running on an older GPU without Tensor Cores,\n",
"you should consider upgrading.\n",
"If you're running a more recent GPU but not seeing Tensor Core usage,\n",
"you should switch to half-precision (16-bit) floating point numbers,\n",
"which Tensor Cores are specialized for."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XxcUf0bBNXy_"
},
"source": [
"- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n",
"at a lower level than just whether something is running at all,\n",
"as in utilization.\n",
"Unlike utilization, reaching 100% is not generally feasible\n",
"and sometimes not desirable.\n",
"Increasing these numbers requires expertise in\n",
"CUDA programming, so we'll target utilization instead."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A88pQn4YMMKc"
},
"source": [
"- **Execution Summary**: This table and pie chart indicate\n",
"how much time within a profiled step\n",
"was spent in each category.\n",
"The value for \"kernel\" execution here\n",
"is equal to the GPU utilization,\n",
"and we want that number to be as close to 100%\n",
"as possible.\n",
"This summary helps us know which\n",
"other operations are taking time,\n",
"like memory being copied between CPU and GPU (`memcpy`)\n",
"or `DataLoader`s executing on the CPU,\n",
"so we can decide where the bottleneck is."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6qjW1RlTQRPv"
},
"source": [
"At the very bottom, you'll find a\n",
"**Performance Recommendation**\n",
"tab that sometimes suggests specific methods for improving performance.\n",
"\n",
"If this tab makes suggestions, you should certainly take them!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pWY5AhrcRQmJ"
},
"source": [
"For more on using the profiler in TensorBoard,\n",
"including some of the other, more detailed views\n",
"available via the \"Views\" dropdown menu, see\n",
"[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mQwrPY_H77H8"
},
"source": [
"## Going deeper with the Chrome Trace Viewer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yhwo7fslvLsH"
},
"source": [
"So far, we've seen summary-level information about our training steps\n",
"in the table from Lightning and in the TensorBoard Overview.\n",
"These give aggregate statistics about the computations that occurred,\n",
"but understanding how to interpret those statistics\n",
"and use them to speed up our networks\n",
"requires understanding just what is\n",
"happening in our training step.\n",
"\n",
"Fundamentally,\n",
"all computations are processes that unfold in time.\n",
"\n",
"If we want to really understand our training step,\n",
"we need to display it that way:\n",
"what operations were occurring,\n",
"on both the CPU and GPU,\n",
"at each moment in time during the training step.\n",
"\n",
"This information on timing is collected in the trace.\n",
"One of the best tools for viewing the trace over time\n",
"is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wUkZItxYc20A"
},
"source": [
"Let's tour the trace we just logged\n",
"with the aim of really understanding just\n",
"what is happening when we call\n",
"`training_step`\n",
"and by extension `.forward`, `.backward`, and `optimizer.step`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9w9F2UA7Qctg"
},
"source": [
"The Chrome Trace Viewer is built into W&B,\n",
"so we can view our traces in their interface.\n",
"\n",
"The cell below embeds the trace inside the notebook,\n",
"but you may wish to open it separately,\n",
"with the \"Open page\" button or by navigating to the URL,\n",
"so that you can interact with it\n",
"as you read the description below.\n",
"Display directly on W&B is also a bit less temperamental\n",
"than display on W&B inside a notebook.\n",
"\n",
"Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n",
"so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n",
"It also can interact poorly with browser extensions (e.g. ad blockers),\n",
"so you may need to deactivate them temporarily in order to see it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OMUs4aby6Rfd"
},
"outputs": [],
"source": [
"trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n",
"trace_url = trace_files_url + \"training_step.pt.trace.json\"\n",
"\n",
"example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n",
"\n",
"print(trace_url)\n",
"IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qNVpGeQtQjMG"
},
"source": [
"> **Heads up!** We're about to do a tour of the\n",
"> precise details of the tracing information logged\n",
"> during the execution of the training code.\n",
"> The only way to learn how to troubleshoot model performance\n",
"> empirically is to look at the details,\n",
"> but the details depend on the precise machine being used\n",
"> -- GPU and CPU and RAM.\n",
"> That means even within Colab,\n",
"> these details change from session to session.\n",
"> So if you don't observe a phenomenon or feature\n",
"> described in the tour below, check out\n",
"> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n",
"> on W&B while reading through the next section of the lab,\n",
"> and return to your trace once you understand the trace viewer better at the end.\n",
"> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n",
"> can sometimes be a bit janky."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kXMcBhnCgdN_"
},
"source": [
"This trace reveals, in nanosecond-level detail,\n",
"what's going on inside of a `training_step`\n",
"on both the GPU and the CPU.\n",
"\n",
"Time is on the horizontal axis.\n",
"Colored bars represent method calls,\n",
"and the methods called by a method are placed underneath it vertically,\n",
"a visualization known as an\n",
"[icicle chart](https://www.brendangregg.com/flamegraphs.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "67BsNzDfVIeg"
},
"source": [
"Let's orient ourselves with some gross features:\n",
"the forwards pass,\n",
"GPU kernel execution,\n",
"the backwards pass,\n",
"and the optimizer step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBEFgtRCKqrh"
},
"source": [
"### The forwards pass"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5nYhiWesVMjK"
},
"source": [
"Type in `resnet` to the search bar in the top-right.\n",
"\n",
"This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n",
"\n",
"It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n",
"\n",
"You can click the arrows next to that tile to partially collapse these blocks.\n",
"\n",
"Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n",
"It should be at roughly the same height.\n",
"\n",
"Clear the search bar so that the trace is in color.\n",
"Zoom in on the area of the forwards pass\n",
"using the \"zoom\" tool in the floating toolbar,\n",
"so you can see more detail.\n",
"The zoom tool is indicated by a two-headed arrow\n",
"pointing into and out of the screen.\n",
"\n",
"Switch to the \"drag\" tool,\n",
"represented by a four-headed arrow.\n",
"Click-and-hold to use this tool to focus\n",
"on different parts of the timeline\n",
"and click on the individual colored boxes\n",
"to see details about a particular method call.\n",
"\n",
"As we go down in the icicle chart,\n",
"we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n",
"to much more precise `cudnn` and `cuda` operations\n",
"(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n",
"\n",
"`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n",
"is the tensor math library in PyTorch\n",
"that links to specific backends like `cudnn`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fq181ybIvLsH"
},
"source": [
"### GPU kernel execution"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IbkWp5aKvLsH"
},
"source": [
"Towards the bottom, you should see a section labeled \"GPU\".\n",
"The label appears on the far left.\n",
"\n",
"Within it, you'll see one or more \"`stream`s\".\n",
"These are units of work on a GPU,\n",
"akin loosely to threads on the CPU.\n",
"\n",
"When there are colored bars in this area,\n",
"the GPU is doing work of some kind.\n",
"The fraction of this bar that is filled in with color\n",
"is the same as the \"GPU Utilization %\" we've seen previously.\n",
"So the first thing to visually assess\n",
"in a trace view of PyTorch code\n",
"is what fraction of this area is filled with color.\n",
"\n",
"In CUDA, work is queued up to be\n",
"placed into streams and completed, on the GPU,\n",
"in a distributed and asynchronous manner.\n",
"\n",
"The selection of which work to do\n",
"is happening on the CPU,\n",
"and that's what we were looking at above.\n",
"\n",
"The CPU and the GPU have to work together to coordinate\n",
"this work.\n",
"\n",
"Type `cuda` into the search bar and you'll see these coordination operations happening:\n",
"`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n",
"\n",
"Running the same PyTorch model\n",
"with the same high level operations like `Conv2d` in different versions of PyTorch,\n",
"on different GPUs, and even on tensors of different sizes will result\n",
"in different choices of concrete kernel operation,\n",
"e.g. different matrix multiplication algorithms.\n",
"\n",
"Type `sync` into the search bar and you'll see places where either work on the GPU\n",
"or work on the CPU needs to await synchronization,\n",
"e.g. copying data from the CPU to the GPU\n",
"or the CPU waiting to decide what to do next\n",
"on the basis of the contents of a tensor.\n",
"\n",
"If you see a \"sync\" block above an area\n",
"where the stream on the GPU is empty,\n",
"you've got a performance bottleneck due to synchronization\n",
"between the CPU and GPU.\n",
"\n",
"To resolve the bottleneck,\n",
"head up the icicle chart until you reach the recognizable\n",
"PyTorch modules and operations.\n",
"Find where they are called in your PyTorch module.\n",
"That's a good place to review your code to understand why the synchronization is happening\n",
"and to remove it if it's not necessary."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XeMPbu_jvLsI"
},
"source": [
"### The backwards pass\n",
"\n",
"Type in `backward` into the search bar.\n",
"\n",
"This will highlight components of our backwards pass.\n",
"\n",
"If you read it from left to right,\n",
"you'll see that it begins by calculating the loss\n",
"(type `NllLoss2DBackward` into the search bar if you can't find it)\n",
"and ends by doing a `ConvolutionBackward`,\n",
"the first layer of the ResNet.\n",
"It is, indeed, backwards.\n",
"\n",
"Like the forwards pass,\n",
"the backwards pass also involves the CPU\n",
"telling the GPU which kernels to run.\n",
"It's typically run in a separate\n",
"thread from the forwards pass,\n",
"so you'll see it separated out from the forwards pass\n",
"in the trace viewer.\n",
"\n",
"Generally, there's no need to specifically optimize the backwards pass --\n",
"removing bottlenecks in the forwards pass results in a fast backwards pass.\n",
"\n",
"One reason why is that these two passes are just\n",
"\"transposes\" of one another,\n",
"so they share a lot of properties,\n",
"and bottlenecks in one become bottlenecks in the other.\n",
"We can choose to optimize either one of the two.\n",
"But the forwards pass is under our direct control,\n",
"so it's easier for us to reason about.\n",
"\n",
"Another reason is that the forwards pass is more likely to have bottlenecks.\n",
"The forwards pass is a dynamic process,\n",
"with each line of Python adding more to the compute graph.\n",
"Backwards passes, on the other hand, use a static compute graph,\n",
"the one just defined by the forwards pass,\n",
"so more optimizations are possible."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gWiDw0vCvLsI"
},
"source": [
"### The optimizer step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ndfkzEdnvLsI"
},
"source": [
"Type in `Adam.step` to the search bar to highlight the computations of the optimizer.\n",
"\n",
"As with the two passes,\n",
"we are still using the CPU\n",
"to launch kernels on the GPU.\n",
"But now the CPU is looping,\n",
"in Python, over the parameters\n",
"and applying the Adam update rules to each.\n",
"\n",
"We now know enough to see that\n",
"this is not great for our GPU utilization:\n",
"there are many areas of gray\n",
"in between the colored bars\n",
"in the GPU stream in this area.\n",
"\n",
"In the time it takes CUDA to multiply\n",
"thousands of numbers,\n",
"Python has not yet finished cleaning up\n",
"after its request for that multiplication.\n",
"\n",
"As of writing in August 2022,\n",
"more efficient optimizers are not a stable part of PyTorch (v1.12), but\n",
"[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n",
"and stable implementations outside of PyTorch.\n",
"The standard implementations are\n",
"[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n",
"not to be confused with the\n",
"[Apex Optimizers Project](https://www.apexoptimizers.com/),\n",
"which is a collection of fitness-themed cheetah NFTs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WX0jxeafvLsI"
},
"source": [
"## Take-aways for PyTorch performance bottleneck troubleshooting"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CugD-bK2vLsI"
},
"source": [
"Our goal here was to learn some basic principles and tools for troubleshooting\n",
"the most common issues and the lowest-hanging fruit in PyTorch code."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SwHwJkVMHYGA"
},
"source": [
"\n",
"Here's an overview in terms of a \"host\",\n",
"generally the CPU,\n",
"and a \"device\", here the GPU.\n",
"\n",
"- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n",
"- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n",
" - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n",
"- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n",
"stores information about the contents of tensors,\n",
"not just their metadata."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gntx28p9cBP5"
},
"source": [
"Towards that goal, we viewed the trace to get an understanding of\n",
"what's going on inside a PyTorch training step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AKvZGPnkeXvq"
},
"source": [
"Here's what this means in terms of troubleshooting bottlenecks.\n",
"\n",
"We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n",
"before the previous kernel finishes.\n",
"\n",
"Ideally, the CPU is actually getting far _ahead_ of execution\n",
"on the GPU.\n",
"If the CPU makes it all the way through the backwards pass before the GPU is done,\n",
"that's great!\n",
"The GPU(s) are the expensive part,\n",
"and it's easy to use multiprocessing so that\n",
"the CPU has other things to do.\n",
"\n",
"This helps explain at least one common piece of advice:\n",
"the larger our batches are,\n",
"the more work the GPU has to do for the same work done by the CPU,\n",
"and so the better our utilization will be."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XMztpa-TccH4"
},
"source": [
"We operationalize our desire to never be waiting on the CPU with a simple metric:\n",
"**100% GPU utilization**, meaning a kernel is running at all times.\n",
"\n",
"This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n",
"\n",
"You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will cost more than 2x as much."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7kYBygfScR6z"
},
"source": [
"Here are some of the most common issues that lead to low GPU utilization, and how to resolve them:\n",
"1. **The CPU is too weak**.\n",
"Because so much of the discussion around DNN performance is about GPUs,\n",
"it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n",
"_Resolution_:\n",
"Use nice CPUs, like\n",
"[threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n",
"2. **Too much Python during the `training_step`**.\n",
"Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on a `__init__`\n",
"that takes longer than running an entire layer.\n",
"_Resolution_:\n",
"Look for low utilization areas of the trace\n",
"and check what's happening on the CPU at that time\n",
"and carefully review the Python code being executed.\n",
"3. **Unnecessary Host/Device synchronization**.\n",
"If one of your operations depends on the values in a tensor,\n",
"like `if xs.mean() >= 0`,\n",
"you'll induce a synchronization between\n",
"the host and the device and possibly lead\n",
"to an expensive and slow copy of data.\n",
"_Resolution_:\n",
"Replace these operations as much as possible\n",
"with purely array-based calculations.\n",
"4. **Bottlenecking on the DataLoader**.\n",
"In addition to coordinating the work on the GPU,\n",
"CPUs often perform heavy data operations,\n",
"including communication over the network\n",
"and writing to/reading from disk.\n",
"These are generally done in parallel to the forwards\n",
"and backwards passes,\n",
"but if they don't finish before that happens,\n",
"they will become the bottleneck.\n",
"_Resolution_:\n",
"Get better hardware for compute,\n",
"memory, and network.\n",
"For software solutions, the answer \n",
"is a bit more complex and application-dependent.\n",
"For generic tips, see\n",
"[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n",
"in the PyTorch forums.\n",
"For techniques in computer vision, see\n",
"[the FFCV library](https://github.com/libffcv/ffcv)\n",
"and for techniques in NLP, see e.g.\n",
"[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n",
"and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i2WYS8bQvLsJ"
},
"source": [
"### Further steps in making DNNs go brrrrrr"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T0wW2_lRKfY1"
},
"source": [
"It's important to note that utilization\n",
"is just an easily measured metric\n",
"that can reveal common bottlenecks.\n",
"Having high utilization does not automatically mean\n",
"that your performance is fully optimized.\n",
"\n",
"For example,\n",
"synchronization events between GPUs\n",
"are counted as kernels,\n",
"so a deadlock during distributed training\n",
"can show up as 100% utilization,\n",
"despite literally no useful work occurring.\n",
"\n",
"Just switching to \n",
"double precision floats, `--precision=64`,\n",
"will generally lead to much higher utilization.\n",
"The GPU operations take longer\n",
"for roughly the same amount of CPU effort,\n",
"but the added precision brings no benefit.\n",
"\n",
"In particular, it doesn't make for models\n",
"that perform better on our correctness metrics,\n",
"like loss and accuracy.\n",
"\n",
"Another useful yardstick to add\n",
"to utilization is examples per second,\n",
"which incorporates how quickly the model is processing data examples\n",
"and calculating gradients.\n",
"\n",
"But really,\n",
"the gold star is _decrease in loss per second_.\n",
"This metric connects model design choices\n",
"and hyperparameters with purely engineering concerns,\n",
"so it disrespects abstraction barriers\n",
"and doesn't generally lead to actionable recommendations,\n",
"but it is, in the end, the real goal:\n",
"make the loss go down faster so we get better models sooner."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EFzPsplfdo_o"
},
"source": [
"For PyTorch internals abstractly,\n",
"see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n",
"\n",
"For more on performance considerations in PyTorch,\n",
"see [Horace He's blog post](https://horace.io/brrr_intro.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RFx-OhF837Bp"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yq6-S6TC38AY"
},
"source": [
"### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n",
"\n",
"One of the most important features for making\n",
"PyTorch run quickly is the\n",
"`MultiprocessingDataLoader`,\n",
"which executes batching of data in a separate process\n",
"from the forwards and backwards passes.\n",
"\n",
"By default in PyTorch,\n",
"this feature is actually turned off,\n",
"via the `DataLoader` argument `num_workers`\n",
"having a default value of `0`,\n",
"but we set the `DEFAULT_NUM_WORKERS`\n",
"to a value based on the number of CPUs\n",
"available on the system running the code.\n",
"\n",
"Re-run the profiling cell,\n",
"but set `num_workers` to `0`\n",
"to turn off multiprocessing.\n",
"\n",
"Compare and contrast the two traces,\n",
"both for total runtime\n",
"(see the time axis at the top of the trace)\n",
"and for utilization.\n",
"\n",
"If you're unable to run the profiles,\n",
"see the results\n",
"[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n",
"which juxtaposes two traces,\n",
"with in-process dataloading on the left and\n",
"multiprocessing dataloading on the right."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5D39w0gXAiha"
},
"source": [
"### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T2i_a5eVeIoA"
},
"source": [
"The file below incorrectly implements and then incorrectly tests\n",
"a simple PyTorch utility for adding five to every entry of a tensor\n",
"and then calculating the sum.\n",
"\n",
"Even worse, it does it with horrible style!\n",
"\n",
"The cells below apply our linting checks\n",
"(after automatically fixing the formatting)\n",
"and run the test.\n",
"\n",
"Fix all of the lints,\n",
"implement the function correctly,\n",
"and then implement some basic tests."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wSon2fB5VVM_"
},
"source": [
"- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n",
"- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n",
"- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n",
"- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n",
"- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aYiRvU4HA84t"
},
"outputs": [],
"source": [
"%%writefile training/fixme.py\n",
"import torch\n",
"from training import run_experiment\n",
"from numpy import *\n",
"import random\n",
"from pathlib import Path\n",
"\n",
"\n",
"\n",
"\n",
"def add_five_and_sum(tensor):\n",
" # this function is not implemented right,\n",
" # but it's supposed to add five to all tensor entries and sum them up\n",
" return 1\n",
"\n",
"def test_add_five_and_sum():\n",
" # and this test isn't right either! plus this isn't exactly a docstring\n",
" all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n",
" all_fives = 5 * all_ones\n",
" assert False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EXJpmvuzT1w0"
},
"outputs": [],
"source": [
"!pre-commit run black --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SRO-oJfdUrcQ"
},
"outputs": [],
"source": [
"!cat training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jM8NHxVbSEQD"
},
"outputs": [],
"source": [
"!pre-commit run --files training/fixme.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kj0VMBSndtkc"
},
"outputs": [],
"source": [
"!pytest training/fixme.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab05_troubleshooting.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
================================================
FILE: lab08/notebooks/lab06_data.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FlH0lCOttCs5"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZUPRHaeetRnT"
},
"source": [
"# Lab 06: Data Annotation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bry3Hr-PcgDs"
},
"source": [
"### What You Will Learn\n",
"\n",
"- How the `IAM` handwriting dataset is structured on disk and how it is processed into an ML-friendly format\n",
"- How to set up a [Label Studio](https://labelstud.io/) data annotation server\n",
"- Just how messy data really is"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vs0LXXlCU6Ix"
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZkQiK7lkgeXm"
},
"source": [
"If you're running this notebook on Google Colab,\n",
"the cell below will run full environment setup.\n",
"\n",
"It should take about three minutes to run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sVx7C7H0PIZC"
},
"outputs": [],
"source": [
"lab_idx = 6\n",
"\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" # needed for inline plots in some contexts\n",
" %matplotlib inline\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DpvaHz9TEGwV"
},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gsXpeXi2EGwV"
},
"outputs": [],
"source": [
"from IPython.display import IFrame\n",
"\n",
"\n",
"IFrame(src=\"https://fsdl.me/2022-lab-06-video-embed\", width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XTkKzEMNR8XZ"
},
"source": [
"# `IAMParagraphs`: From annotated data to a PyTorch `Dataset`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3mQLbjuiwZuj"
},
"source": [
"We've used the `text_recognizer.data` submodule\n",
"and its `LightningDataModule`s -- `IAMLines` and `IAMParagraphs`\n",
"for lines and paragraphs of handwritten text\n",
"from the\n",
"[IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database).\n",
"\n",
"These classes convert data from a database-friendly format\n",
"designed for storage and transfer into the\n",
"format our DNNs expect:\n",
"PyTorch `Tensor`s.\n",
"\n",
"In this section,\n",
"we'll walk through that process in detail.\n",
"\n",
"In the following section,\n",
"we'll see how data\n",
"goes from signals measured in the world\n",
"to the format we consume here."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "499c23a6"
},
"source": [
"## Dataset structure on disk"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a3438d2e"
},
"source": [
"We begin by downloading the raw data to disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "18900eec"
},
"outputs": [],
"source": [
"from text_recognizer.data.iam import IAM\n",
"\n",
"iam = IAM()\n",
"iam.prepare_data()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a332f359"
},
"source": [
"The `IAM` dataset is downloaded as a zip file\n",
"and then unzipped:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d6c44266"
},
"outputs": [],
"source": [
"from text_recognizer.metadata.iam import DL_DATA_DIRNAME\n",
"\n",
"\n",
"iam_dir = DL_DATA_DIRNAME\n",
"!ls {iam_dir}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8463c2d1"
},
"source": [
"The unzipped dataset is not simply a flat directory of files.\n",
"\n",
"Instead, there are a number of subfolders,\n",
"each of which contains a particular type of data or metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "536924f7"
},
"outputs": [],
"source": [
"iamdb = iam_dir / \"iamdb\"\n",
"\n",
"!du -h {iamdb}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b745a594"
},
"source": [
"For example, the `task` folder contains metadata about canonical dataset splits:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "84c21f75"
},
"outputs": [],
"source": [
"!find {iamdb / \"task\"} | grep \"\\\\.txt$\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mEb0Pdm4vIHe"
},
"source": [
"We find the images of handwritten text in the `forms` folder.\n",
"\n",
"An individual \"datapoint\" in `IAM` is a \"form\",\n",
"because the humans whose hands wrote the text were prompted to write on \"forms\",\n",
"as below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "945d5e3a"
},
"outputs": [],
"source": [
"from IPython.display import Image\n",
"\n",
"\n",
"form_fn, = !find {iamdb}/forms | grep \".jpg$\" | sort | head -n 1\n",
"\n",
"print(form_fn)\n",
"Image(filename=form_fn, width=\"360\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b9e9e384"
},
"source": [
"Meanwhile, the `xml` files contain the data annotations,\n",
"written out as structured text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6add5c5a"
},
"outputs": [],
"source": [
"xml_fn, = !find {iamdb}/xml | grep \"\\.xml$\" | sort | head -n 1\n",
"\n",
"!cat {xml_fn} | grep -A 100 \"handwritten-part\" | grep \"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MX9n-Zed8G_T"
},
"source": [
"# Lab 08: Monitoring"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tv8O0V0EV09z"
},
"source": [
"## What You Will Learn\n",
"\n",
"- How to add user feedback and model monitoring to a Gradio-based app\n",
"- How to analyze this logged information to uncover and debug model issues\n",
"- Just how large the gap between benchmark data and data from users can be, and what to do about it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "45D6GuSwvT7d"
},
"outputs": [],
"source": [
"lab_idx = 8\n",
"\n",
"\n",
"if \"bootstrap\" not in locals() or bootstrap.run:\n",
" # path management for Python\n",
" pythonpath, = !echo $PYTHONPATH\n",
" if \".\" not in pythonpath.split(\":\"):\n",
" pythonpath = \".:\" + pythonpath\n",
" %env PYTHONPATH={pythonpath}\n",
" !echo $PYTHONPATH\n",
"\n",
" # get both Colab and local notebooks into the same state\n",
" !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n",
" import bootstrap\n",
"\n",
" %matplotlib inline\n",
"\n",
" # change into the lab directory\n",
" bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n",
"\n",
" bootstrap.run = False # change to True re-run setup\n",
"\n",
"!pwd\n",
"%ls"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cUdTJE54V09z"
},
"source": [
"### Follow along with a video walkthrough on YouTube:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4J9hDxNsV09z"
},
"outputs": [],
"source": [
"from IPython.display import IFrame\n",
"\n",
"\n",
"IFrame(src=\"https://fsdl.me/2022-lab-08-video-embed\", width=\"100%\", height=720)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zvi49122ho0r"
},
"source": [
"# Basic user feedback with `gradio`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "56y2r9IYkY7A"
},
"source": [
"On top of the basic health check and event logging\n",
"necessary for any distributed system\n",
"(provided for our application by\n",
"[AWS CloudWatch](https://aws.amazon.com/cloudwatch/),\n",
"which collects logs from EC2 and Lambda instances),\n",
"ML-powered applications need specialized monitoring solutions.\n",
"\n",
"In particular, we want to give users a way\n",
"to report issues or indicate their level of satisfaction\n",
"with the model.\n",
"\n",
"The UI-building framework we're using, `gradio`,\n",
"comes with user feedback, under the name \"flagging\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wXq4jcjCkNap"
},
"source": [
"To see how this works, we first spin up our front end,\n",
"pointed at the AWS Lambda backend,\n",
"as in\n",
"[the previous lab](https://fsdl.me/lab07-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rAZrYRnSiMER"
},
"outputs": [],
"source": [
"from app_gradio import app\n",
"\n",
"\n",
"lambda_url = \"https://3akxma777p53w57mmdika3sflu0fvazm.lambda-url.us-west-1.on.aws/\"\n",
"\n",
"backend = app.PredictorBackend(url=lambda_url)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "STXn1XaHkU42"
},
"source": [
"And adding user feedback collection\n",
"is as easy as passing `flagging=True`.\n",
"\n",
"> The `flagging` argument is here being given to\n",
"code from the FSDL codebase, `app.make_frontend`,\n",
"but you can just pass\n",
"`flagging=True` directly\n",
"to the `gradio.Interface` class.\n",
"In between in our code,\n",
"we have a bit of extra logic\n",
"so that we can support\n",
"multiple different storage backends for logging flagged data.\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mxZQRklXV091"
},
"source": [
"Run the cell below to create a frontend\n",
"(accessible on a public Gradio URL and inside the notebook)\n",
"and observe the new \"flagging\" buttons underneath the outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Kgygx8d5ip9V"
},
"outputs": [],
"source": [
"frontend = app.make_frontend(fn=backend.run, flagging=True)\n",
"frontend.launch(share=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zV2tu8HTk242"
},
"source": [
"Click one of the buttons to trigger flagging.\n",
"\n",
"It doesn't need to be a legitimate issue with the model's outputs.\n",
"\n",
"Instead of just submitting one of the example images,\n",
"you might additionally use the image editor\n",
"(pencil button on uploaded images)\n",
"to crop it."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gJV79PDIk-4S"
},
"source": [
"Flagged data is stored on the server's local filesystem,\n",
"by default in the `flagged/` directory\n",
"as a `.csv` file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RbCcCxvHi2jh"
},
"outputs": [],
"source": [
"!ls flagged"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Koh1SP9NlA6y"
},
"source": [
"We can load the `.csv` with `pandas`,\n",
"the Python library for handling tabular data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OJCnIsfEjC05"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"log_path = Path(\"flagged\") / \"log.csv\"\n",
"\n",
"flagged_df = None\n",
"if log_path.exists():\n",
" flagged_df = pd.read_csv(log_path, quotechar=\"'\") # quoting can be painful for natural text data\n",
" flagged_df = flagged_df.dropna(subset=[\"Handwritten Text\"]) # drop any flags without an image\n",
"\n",
"flagged_df"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KZieT-FgldKa"
},
"source": [
"Notice that richer data, like images, is stored with references --\n",
"here, the names of local files.\n",
"\n",
"This is a common pattern:\n",
"binary data doesn't go in the database,\n",
"only pointers to binary data.\n",
"\n",
"We can then read the data back to analyze our model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gWG3T3Qql_99"
},
"outputs": [],
"source": [
"from IPython.display import display\n",
"\n",
"from text_recognizer.util import read_image_pil\n",
"\n",
"\n",
"if flagged_df is not None:\n",
" row = flagged_df.iloc[-1]\n",
" print(row[\"output\"])\n",
" display(read_image_pil(Path(\"flagged\") / row[\"Handwritten Text\"]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0gIpfRMFl9_D"
},
"source": [
"We encourage you to play around with the model for a bit,\n",
"uploading your own images.\n",
"\n",
"This is an important step in understanding your model\n",
"and your domain --\n",
"especially when you're familiar with the data types involved.\n",
"\n",
"But even when you are,\n",
"we expect you'll quickly find\n",
"that you run out of ideas\n",
"for different ways to probe your model.\n",
"\n",
"To really learn more about your model,\n",
"you'll need some actual users.\n",
"\n",
"In small projects,\n",
"these can be other team members who are less enmeshed\n",
"in the details of model development and data munging.\n",
"\n",
"But to create something that can appeal to a broader set of users,\n",
"you'll want to collect feedback from your potential userbase."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RHArpXNyRtg7"
},
"source": [
"# Debugging production models with `gantry`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hbGCYG0BmvdE"
},
"source": [
"Unfortunately, this aspect of model development\n",
"is particularly challenging to replicate in\n",
"a course setting, especially a MOOC --\n",
"where do these users come from?\n",
"\n",
"As part of the 2022 edition of the course, we've\n",
"[been running a text recognizer application](https://fsdl-text-recognizer.ngrok.io)\n",
"and collecting user feedback on it.\n",
"\n",
"Rather than saving user feedback data locally,\n",
"as with the CSV logger above,\n",
"we've been sending that data to\n",
"[Gantry](https://gantry.io/),\n",
"a model monitoring and continual learning tool.\n",
"\n",
"That's because local logging is a very bad idea:\n",
"as logs grow, the storage needs and read/write time grow,\n",
"which unduly burdens the frontend server.\n",
"\n",
"The `gradio` library supports logging of user-flagged data\n",
"to arbitrary backends via\n",
"`FlaggingCallback`s.\n",
"\n",
"So there are some new elements to the codebase:\n",
"most importantly here, a `GantryImageToTextLogger`\n",
"that inherits from `gradio.FlaggingCallback`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pptT76DWmlB0"
},
"outputs": [],
"source": [
"from app_gradio import flagging\n",
"\n",
"\n",
"print(flagging.GantryImageToTextLogger.__init__.__doc__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-3HevRM2YkbZ"
},
"source": [
"If we add this `Callback` to our setup --\n",
"and add a Gantry API key to our environment --\n",
"then we can start sending data to Gantry's service."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UHnIV0e_a9o6"
},
"outputs": [],
"source": [
"app.make_frontend??"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jJcfaWNpRzJF"
},
"source": [
"The short version of how the logging works:\n",
"we upload flagged images to S3 for storage (`GantryImageToTextLogger._to_s3`)\n",
"and send the URL to Gantry along with the outputs (`GantryImageToTextLogger._to_gantry`)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uviSZDTma1RT"
},
"source": [
"Below, we'll download that data\n",
"and look through it in the notebook,\n",
"using typical Python data analysis tools,\n",
"like `pandas` and `seaborn`.\n",
"\n",
"By analogy to\n",
"[EDA](https://en.wikipedia.org/wiki/Exploratory_data_analysis),\n",
"consider this an \"exploratory model analysis\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LFxypmESXESL"
},
"outputs": [],
"source": [
"import gantry.query as gq\n",
"\n",
"\n",
"read_only_key = \"VpPfHPDSk9e9KKAgbiHBh7mqF_8\"\n",
"gq.init(api_key=read_only_key)\n",
"\n",
"gdf = gq.query( # we query Gantry's service with the following parameters:\n",
" application=\"fsdl-text-recognizer\", # which tracked application should we draw from?\n",
" # what time period should we pull data from? here, the first two months the app was up\n",
" start_time=\"2022-07-01T07:00:00.000Z\",\n",
" end_time=\"2022-09-01T06:59:00.000Z\",\n",
")\n",
"\n",
"raw_df = gdf.fetch()\n",
"df = raw_df.dropna(axis=\"columns\", how=\"all\") # remove any irrelevant columns\n",
"df = df[df[\"tags.env\"] == \"dev\"] # filter down to info logged from the development environment\n",
"print(\"number of rows:\", len(df))\n",
"df = df.drop_duplicates(keep=\"first\", subset=\"inputs.image\") # remove repeated reports, eg of example images\n",
"print(\"number of unique rows:\", len(df))\n",
"\n",
"print(\"\\ncolumns:\")\n",
"df.columns"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bN6YNmnCV094"
},
"source": [
"We'll walk through what each of these columns means,\n",
"but the three most important are the ones we logged directly from the application:\n",
"`flag`s, `inputs.image`s, and `output_text`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c8SEwiAXV094"
},
"outputs": [],
"source": [
"main_columns = [column for column in df.columns if \"(\" not in column] # derived columns have a \"function call\" in the name\n",
"main_columns"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i8HfH-BIV094"
},
"source": [
"If you're interested in playing\n",
"around with the data yourself\n",
"in Gantry's UI,\n",
"as we do in the\n",
"[video walkthrough for the lab](https://fsdl.me/2022-lab-08-video),\n",
"you'll need a Gantry account.\n",
"\n",
"Gantry is currently in closed beta.\n",
"Unlike training or experiment management,\n",
"model monitoring and continual learning\n",
"is at the frontier of applied ML,\n",
"so tooling is just starting to roll out.\n",
"\n",
"FSDL students are invited to this beta and\n",
"[can create a \"read-only\" account here](https://gantry.io/fsdl-signup)\n",
"so they can view the data in the UI\n",
"and explore it themselves.\n",
"\n",
"As an early startup,\n",
"Gantry is very interested in feedback\n",
"from practitioners!\n",
"So if you do try out the Gantry UI,\n",
"send any impressions, bug reports, or ideas to\n",
"`support@gantry.io`.\n",
"\n",
"This is also a chance for you\n",
"to influence the development\n",
"of a new tool that could one day\n",
"end up at the center of continual learning\n",
"workflows --\n",
"as when\n",
"[FSDL students in spring 2019 got a chance to be early users of W&B](https://www.youtube.com/watch?t=1468&v=Eiz1zcqrqw0&feature=youtu.be&ab_channel=FullStackDeepLearning)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RmTFHvxHi4el"
},
"source": [
"## Basic stats and behavioral monitoring"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hYSQ0r7eV094"
},
"source": [
"We start by just getting some basic statistics.\n",
"\n",
"For example, we can get descriptive statistics for\n",
"the information we've logged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fb3BMn7gfQRI"
},
"outputs": [],
"source": [
"df[\"feedback.flag\"].describe()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T9OseYhc1Q8i"
},
"source": [
"Note that the format we're working with is the `pandas.DataFrame` --\n",
"a standard format for tables in Python.\n",
"\n",
"`pandas` can be\n",
"[very tricky](https://github.com/chiphuyen/just-pandas-things).\n",
"\n",
"It's not so bad when doing exploratory analysis like this,\n",
"but take care when using it in production settings!\n",
"\n",
"If you'd like to learn more `pandas`,\n",
"[Brandon Rhodes's `pandas` tutorial from PyCon 2015](https://www.youtube.com/watch?v=5JnMutdy6Fw&ab_channel=PyCon2015)\n",
"is still one of the best introductions,\n",
"even though it's nearly a decade old."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eG15SMkgV095"
},
"source": [
"`pandas` objects support sampling with `.sample`,\n",
"which is useful for quick \"spot-checking\" of data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FZ5BRRqjc1Of"
},
"outputs": [],
"source": [
"df[\"feedback.flag\"].sample(10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w3rZaYwSzu-D"
},
"source": [
"Unlike many other kinds of applications,\n",
"toxic and offensive behavior is\n",
"one of the most critical potential issues with\n",
"many ML models,\n",
"from\n",
"[generative models like GPT-3](https://www.middlebury.edu/institute/sites/www.middlebury.edu.institute/files/2020-09/gpt3-article.pdf)\n",
"to even humble\n",
"[image labeling models](https://archive.nytimes.com/bits.blogs.nytimes.com/2015/07/01/google-photos-mistakenly-labels-black-people-gorillas/).\n",
"\n",
"So ML models, especially when newly deployed\n",
"or when encountering new user bases,\n",
"need careful supervision."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-CbdSz0hzze7"
},
"source": [
"We use a\n",
"[Gantry tool called Projections](https://docs.gantry.io/en/stable/guides/projections.html)\n",
"to apply the NLP models from the\n",
"[`detoxify` suite](https://github.com/unitaryai/detoxify),\n",
"which score text for features like obscenity and identity attacks,\n",
"to our model's outputs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1Z4lsgRcpQql"
},
"source": [
"To get a quick plot of the resulting values,\n",
"we can use the `pandas` built-in interface\n",
"to `matplotlib`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9UbBg947fAsh"
},
"outputs": [],
"source": [
"df.plot(y=\"detoxify.obscene(outputs.output_text)\", kind=\"hist\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qxiIXGf0pVd5"
},
"source": [
"Without context, this chart isn't super useful --\n",
"is a score of `obscene=0.12` bad?\n",
"\n",
"We need a baseline!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UbOeOkzQgBDE"
},
"source": [
"Once the model is stable in production,\n",
"we can compare the values across time --\n",
"grouping or filtering production data by timestamp.\n",
"\n",
"Here, for this first version of the model,\n",
"we compare these results with the results on the test data,\n",
"which was also ingested with `gantry`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ooa-Al48f_au"
},
"outputs": [],
"source": [
"test_df = raw_df.dropna(axis=\"columns\", how=\"all\") # remove any irrelevant columns\n",
"test_df = test_df[test_df[\"tags.env\"] == \"test\"] # filter down to info logged from the test environment\n",
"\n",
"test_df.sample(10) # show a sample"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TssF7sSX1Q8k"
},
"source": [
"To compare the two `DataFrame`s,\n",
"we `concat`enate them together\n",
"and add in some metadata\n",
"identifying where the observations came from.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oXWqfOdfgi4o"
},
"outputs": [],
"source": [
"test_df[\"environment\"] = \"test\"\n",
"df[\"environment\"] = \"prod\"\n",
"\n",
"comparison_df = pd.concat([df, test_df])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5fp9gAX_V09_"
},
"source": [
"From there, we can use grouping to calculate statistics of interest:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NIGBxyZIV09_"
},
"outputs": [],
"source": [
"stats = comparison_df.groupby(\"environment\").describe()\n",
"\n",
"stats[\"detoxify.obscene(outputs.output_text)\"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2G2tVhhY1Q8k"
},
"source": [
"These descriptive statistics are helpful,\n",
"but as with our simple plot above,\n",
"we want to _look_ at the data.\n",
"\n",
"Exploratory data analysis is typically very visual --\n",
"the goal is to find phenomena so obvious\n",
"that statistical testing is an afterthought --\n",
"and so is exploratory model analysis.\n",
"\n",
"`matplotlib` is based on plotting arrays,\n",
"rather than `DataFrame`s or other tabular data,\n",
"so it's not a great fit on its own here,\n",
"unless we want to tolerate a lot of boilerplate.\n",
"\n",
"`pandas` has basic built-in plotting\n",
"that interfaces with `matplotlib`,\n",
"but it's neither ergonomic for comparisons nor very flexible\n",
"without just dropping back to matplotlib.\n",
"\n",
"There are a number of other Python plotting libraries,\n",
"many with an emphasis on share-ability and interaction\n",
"([Vega-Altair](https://altair-viz.github.io/),\n",
"[`bokeh`](http://bokeh.org/),\n",
"and\n",
"[Plotly](https://plotly.com/),\n",
"to name a few)\n",
"and others with an emphasis on usability\n",
"(e.g. [`ggplot`](https://realpython.com/ggplot-python/)).\n",
"\n",
"The one that we like for in-notebook analysis\n",
"that balances ease of use\n",
"on tabular data with flexibility is\n",
"[`seaborn`](https://seaborn.pydata.org/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7nZV8uoG1Q8k"
},
"source": [
"Comparing the distributions of the `detoxify.obscene` metric\n",
"is a single function call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WnGxCz1f1Q8k"
},
"outputs": [],
"source": [
"import seaborn as sns\n",
"\n",
"\n",
"sns.displot( # plot the dis-tribution\n",
" data=comparison_df, # of data from this df\n",
" # specifically, this column, along the x-axis\n",
" x=\"detoxify.obscene(outputs.output_text)\",\n",
" # and split it up (in color/hue) by this column\n",
" hue=\"environment\"\n",
");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jO6FuRCQV0-A"
},
"source": [
"We can quickly see that the obscenity scores according to `detoxify`\n",
"are generally lower in our `prod`uction environment,\n",
"so we don't have a reason to suspect\n",
"our model is behaving too badly in production\n",
"-- though see the exercises for more on this!\n",
"\n",
"We can see the same thing\n",
"without having to write query, cleaning, and plotting code\n",
"[in the Gantry UI here](https://app.gantry.io/applications/fsdl-text-recognizer/distribution?view=2022-class&compare=test-ingest) --\n",
"note that viewing the dashboard requires a Gantry account,\n",
"which you can sign up for\n",
"[here](https://gantry.io/fsdl-signup)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iKZ0l2MCjlDn"
},
"source": [
"## Debugging the Text Recognizer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ovp8fZ1GpUet"
},
"source": [
"In our application,\n",
"we don't have user corrections or labels from annotators,\n",
"so we can't calculate an accuracy, a loss, or a character error rate.\n",
"\n",
"We instead look for signals that are correlated with\n",
"those values.\n",
"\n",
"This approach has limits\n",
"(see, e.g. the analysis in the\n",
"[MLDeMon paper](https://arxiv.org/abs/2104.13621))\n",
"and setting alerts or test failures on things that are only correlated with,\n",
"rather than directly caused by, poor performance is a bad idea.\n",
"\n",
"But it's very useful to have this information logged\n",
"to catch large errors at a glance\n",
"or to provide tools for slicing, filtering, and grouping data\n",
"while doing exploratory model analysis or debugging."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0YauDrY51Q8l"
},
"source": [
"We can also compute these signals with Gantry Projections.\n",
"\n",
"Low entropy (e.g. repetition) is a failure mode of language models,\n",
"as is excessively high entropy (e.g. uniformly random text).\n",
"\n",
"We can review the output text entropy distributions in\n",
"production and during testing\n",
"by plotting them against one another\n",
"(here or\n",
"[in the Gantry UI](https://app.gantry.io/applications/fsdl-text-recognizer/distribution?view=2022-class&compare=test-ingest))."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "czepR9o7l2FO"
},
"outputs": [],
"source": [
"sns.displot(\n",
" data=comparison_df,\n",
" x=\"text_stats.basics.entropy(outputs.output_text)\",\n",
" hue=\"environment\"\n",
");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8LiFvkoR1Q8l"
},
"source": [
"It appears there are more low-entropy strings in the model's outputs in production.\n",
"\n",
"With models that operate on human-relevant data,\n",
"like text and images,\n",
"it's important to look at the raw data,\n",
"not just projections.\n",
"\n",
"Let's take a look at a sample of outputs from the model running on test data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FQ9kTz2ZmOwR"
},
"outputs": [],
"source": [
"test_df[\"outputs.output_text\"].sample(10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BpZ_35uD1Q8l"
},
"source": [
"The results are not incredible, but they are recognizably \"English with typos\"."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NVlj3vYf1Q8l"
},
"source": [
"Let's look specifically at low entropy examples from production\n",
"(we can also view this\n",
"[filtered data in the Gantry UI](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=2022-class-low-entropy&compare=test-ingest))."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p0dkx1VzoJ9C"
},
"outputs": [],
"source": [
"df.loc[df[\"text_stats.basics.entropy(outputs.output_text)\"] < 5][\"outputs.output_text\"].sample(10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iMmcPuynV0-C"
},
"source": [
"Yikes! Lots of repetitive gibberish."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "stStBoCZ1Q8m"
},
"source": [
"Knowing the outputs are bad,\n",
"there are two culprits:\n",
"the input-output mapping (aka the model)\n",
"or the inputs."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nFaGYnjcmKf6"
},
"source": [
"We ran the same model in a similar environment\n",
"to get those outputs,\n",
"so it's most likely due to some difference in the inputs.\n",
"\n",
"Let's check them!\n",
"\n",
"We added Gantry Projections to look at the distribution of pixel values as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uSwnexFRlaIV"
},
"outputs": [],
"source": [
"sns.displot(\n",
" data=comparison_df,\n",
" x=\"image.greyscale_image_mean(inputs.image)\",\n",
" hue=\"environment\"\n",
");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iqkWkM45yMgV"
},
"source": [
"There's a huge difference in mean pixel values --\n",
"almost all images have mean intensities that are very dark in the testing environment,\n",
"but we see both dark and light images in production.\n",
"\n",
"Reviewing the\n",
"[raw data in Gantry](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=2022-class-low-entropy&compare=test-ingest)\n",
"confirms that we are getting images with very different brightnesses in production\n",
"and whiffing the predictions\n",
"-- along with images that reveal a number of other interesting failure modes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X5uWeR6n1Q8m"
},
"source": [
"To take a look locally,\n",
"we'll need to pull the images down from S3,\n",
"where they are stored."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NbNMlevz1Q8m"
},
"source": [
"The cell below defines a quick utility for\n",
"reading from S3 without authentication.\n",
"\n",
"It is based on the `smart_open` and `boto3` libraries,\n",
"which we briefly saw in the\n",
"[model deployment lab](https://fsdl.me/lab07-colab)\n",
"and the\n",
"[data annotation lab](https://fsdl.me/lab06-colab)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-FNIm0MOovtu"
},
"outputs": [],
"source": [
"import boto3\n",
"from botocore import UNSIGNED\n",
"from botocore.config import Config\n",
"import smart_open\n",
"\n",
"from text_recognizer.util import read_image_pil_file\n",
"\n",
"# spin up a client for communicating with s3 without authenticating (\"UNSIGNED\" activity)\n",
"s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))\n",
"unsigned_params = {\"client\": s3}\n",
"\n",
"def read_image_unsigned(image_uri, grayscale=False):\n",
" with smart_open.open(image_uri, \"rb\", transport_params=unsigned_params) as image_file:\n",
" return read_image_pil_file(image_file, grayscale)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SxBpmPYrV0-F"
},
"source": [
"Run the cell below to repeatedly sample a random input/output pair\n",
"flagged in production."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xy90rzcWobuk"
},
"outputs": [],
"source": [
"row = df.sample().iloc[0]\n",
"print(\"image url:\", row[\"inputs.image\"])\n",
"print(\"prediction:\", row[\"outputs.output_text\"])\n",
"read_image_unsigned(row[\"inputs.image\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oFdT2W2xtOGx"
},
"source": [
"### Take-aways for developing models\n",
"\n",
"The most immediate take-away from reviewing just a few examples is that\n",
"user data is way more heterogeneous than train/val/test data!\n",
"\n",
"This is a\n",
"[fairly](https://browsee.io/blog/a-guide-to-session-replays-for-product-managers/)\n",
"[universal](https://medium.com/@beasles/edge-case-responsive-design-9b610138ddbd)\n",
"[finding](https://quoteinvestigator.com/2021/05/04/no-plan/).\n",
"\n",
"Let's also consider some specific failure modes in our case\n",
"and how we might resolve them:\n",
"\n",
"- Failure mode: Users mostly provide images with dark text on light background, but we train on dark background.\n",
" - Resolution: We could check image brightness and flip if needed,\n",
" but this feels like a cop-out -- most text is dark on a light background!\n",
" - Resolution: We add image brightness inversion to our train-time augmentations.\n",
"- Failure mode: Users expect our \"handwritten text recognition\" tool to work with printed and digital text.\n",
" - Resolution: We could try better sign-posting and user education,\n",
" but this is also something of a cop-out.\n",
" Users expect the tool to work on all text,\n",
" so we shouldn't violate that expectation.\n",
" - Resolution: We synthesize digital text data --\n",
" text rendering is a feature of just about any mature programming language.\n",
"- Failure mode: Users provide text on heterogeneous backgrounds\n",
" - Resolution: We collect or synthesize more heterogeneous data,\n",
" e.g. placing text (with or without background coloring)\n",
" on top of random image backgrounds.\n",
"- Failure mode: Users provide text with characters and symbols outside of our dictionary.\n",
" - Resolution: We can expand the model outputs and collect more heterogeneous data\n",
"- Failure mode: Users provide images with multiple blocks of text\n",
" - Resolution: We develop an architecture/task definition that can handle multiple regions.\n",
" We'll need to collect and/or synthesize data to support that."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9rQH6zI8u7WN"
},
"source": [
"Notice: these are almost entirely changes to data,\n",
"and most of them involve collecting more or synthesizing it.\n",
"\n",
"This is very much typical!\n",
"\n",
"Data drives improvements to models,\n",
"[even at scale](https://www.lesswrong.com/posts/6Fpvch8RR29qLEWNH/chinchilla-s-wild-implications)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2P5MrIW5V0-F"
},
"source": [
"### Take-aways for exploratory model analysis"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mfMf1wwR1Q8n"
},
"source": [
"Notice that we had to write a lot of code,\n",
"which we developed and ran in a\n",
"tight interactive loop.\n",
"\n",
"This type of code is very hard to turn into scripts --\n",
"how do you trigger an alert on a plot? --\n",
"which makes it brittle and hard to version and share.\n",
"\n",
"It's also based on possibly very large-scale data artifacts.\n",
"\n",
"The right tool for this job is a UI\n",
"on top of a database.\n",
"\n",
"In the\n",
"[video walkthrough for this lab](https://fsdl.me/2022-lab-08-video),\n",
"we do effectively the same analysis,\n",
"but inside Gantry,\n",
"which makes the process more fluid.\n",
"\n",
"Gantry is still in closed beta,\n",
"but if you're interested in applying it to your own applications, you can\n",
"[join the waitlist](https://gantry.io/waitlist/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M73gui0XhgCl"
},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mWWrmGiThhMw"
},
"source": [
"### 🌟 Examine the test data strings, both output and ground truth."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "km0nv0Mghmd_"
},
"source": [
"We compared our production obscenity metric to the test-time values of that same metric\n",
"and determined that we had not gotten worse,\n",
"so things were fine.\n",
"\n",
"But what if the test-time baseline is bad?\n",
"\n",
"Review the raw test ground truth data\n",
"[here](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=test-ingest),\n",
"if you\n",
"[signed up for a Gantry account](https://gantry.io/fsdl-signup),\n",
"or by looking at the contents of `test_df` above.\n",
"\n",
"Sort by `detoxify.identity_attack(feedback.ground_truth_string)`\n",
"or filter to only high values of that metric.\n",
"\n",
"Review the example `feedback.ground_truth_string` texts and consider:\n",
"is this the subset of English\n",
"we want the model to be training on?\n",
"what objections might be raised to the contents?\n",
"\n",
"You might also look for cases where the `detoxify` models misunderstood meaning --\n",
"e.g. an innocuous use of a word that's often used objectionably."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1Q6mWRwS1Q8t"
},
"source": [
"### 🌟🌟 Start building \"regression testing suites\" by doing error analysis on these examples."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jfsCnjCg1Q8t"
},
"source": [
"Do this by going through feedback data one image/text pair at a time --\n",
"[in Gantry](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=2022-class-low-entropy)\n",
"or inside this notebook.\n",
"\n",
"Start by just taking notes on each example\n",
"(anywhere -- Google Sheets/Excel/Notion, or just a sheet of paper).\n",
"\n",
"The primary question you should ask is:\n",
"how does this example differ from the data shown in training?\n",
"\n",
"Check\n",
"[this W&B Artifact page](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/artifacts/run_table/run-1vrnrd8p-trainpredictions/v194/files/train/predictions.table.json#f5854c9c18f6c24a4e99)\n",
"to see what training data\n",
"(including augmentation)\n",
"looks like.\n",
"\n",
"Once you have some notes,\n",
"try and formalize them into a small number of \"failure modes\" --\n",
"you can choose to align them with the failure modes described in the section\n",
"on take-aways for model development or not.\n",
"\n",
"If you want to finish the loop,\n",
"you might set up Label Studio, as in\n",
"[the data annotation lab](https://fsdl.me/lab06-colab).\n",
"An annotator should add at least a\n",
"\"label\" that gives the type of issue\n",
"and perhaps also add a text annotation\n",
"while they are at it."
]
}
],
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: lab08/tasks/lint.sh
================================================
#!/bin/bash
# Run each pre-commit linting hook in turn, continuing past failures,
# then exit non-zero if any hook failed.
set -uo pipefail
set +e  # do not stop at the first failing hook; report every problem

FAILURE=false

# Hooks, in order:
#   black      - apply automatic formatting
#   flake8     - check for python code style violations, see .flake8 for details
#   shellcheck - check for shell scripting style violations and common bugs
#   mypy       - check python types
for hook in black flake8 shellcheck mypy; do
  echo "$hook"
  pre-commit run "$hook" || FAILURE=true
done

if [[ "$FAILURE" == true ]]; then
  echo "Linting failed"
  exit 1
fi

echo "Linting passed"
exit 0
================================================
FILE: lab08/text_recognizer/__init__.py
================================================
"""Modules for creating and running a text recognizer."""
================================================
FILE: lab08/text_recognizer/callbacks/__init__.py
================================================
from .model import ModelSizeLogger
from .optim import LearningRateMonitor
from . import imtotext
from .imtotext import ImageToTextTableLogger as ImageToTextLogger
================================================
FILE: lab08/text_recognizer/callbacks/imtotext.py
================================================
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
from .util import check_and_warn
class ImageToTextTableLogger(pl.Callback):
    """Logs the inputs and outputs of an image-to-text model to Weights & Biases as a table.

    Each logged table has the columns input_image, ground_truth_string, and predicted_string.

    Args:
        max_images_to_log: examples logged per batch; clamped to the range [1, 32].
        on_train: whether to also log training batches (validation batches are always logged).
    """

    def __init__(self, max_images_to_log=32, on_train=True):
        super().__init__()
        self.max_images_to_log = min(max(max_images_to_log, 1), 32)
        self.on_train = on_train
        self._required_keys = ["gt_strs", "pred_strs"]

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, output, batch, batch_idx):
        if not self.on_train or not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_table", "image-to-text table"):
            return
        self._log_image_text_table(trainer, output, batch, "train/predictions")

    @rank_zero_only
    def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 so the hook still works if Lightning does not pass it
        if not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_table", "image-to-text table"):
            return
        self._log_image_text_table(trainer, output, batch, "validation/predictions")

    def _log_image_text_table(self, trainer, output, batch, key):
        """Log up to max_images_to_log (image, ground truth, prediction) rows under `key`."""
        xs, _ = batch
        mx = self.max_images_to_log
        images = [wandb.Image(x) for x in xs[:mx]]
        gt_strs = output["gt_strs"][:mx]
        pred_strs = output["pred_strs"][:mx]
        columns = ["input_image", "ground_truth_string", "predicted_string"]
        trainer.logger.log_table(key=key, columns=columns, data=list(zip(images, gt_strs, pred_strs)))

    def has_metrics(self, output):
        """Return True if `output` is a mapping containing every key needed for logging."""
        # Lightning step outputs are not always dicts (e.g. a bare loss tensor or None)
        if not hasattr(output, "keys"):
            return False
        return all(key in output for key in self._required_keys)
class ImageToTextCaptionLogger(pl.Callback):
    """Logs the inputs and outputs of an image-to-text model to Weights & Biases.

    Input images are logged via the logger's `log_image`, captioned with the predicted strings.

    Args:
        max_images_to_log: images logged per batch; clamped to the range [1, 32].
        on_train: whether to also log training batches (validation/test batches are always logged).
    """

    def __init__(self, max_images_to_log=32, on_train=True):
        super().__init__()
        self.max_images_to_log = min(max(max_images_to_log, 1), 32)
        self.on_train = on_train
        self._required_keys = ["gt_strs", "pred_strs"]

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, output, batch, batch_idx):
        # respect on_train, consistent with ImageToTextTableLogger
        if not self.on_train or not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_image", "image-to-text"):
            return
        self._log_image_text_caption(trainer, output, batch, "train/predictions")

    @rank_zero_only
    def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 so the hook still works if Lightning does not pass it
        if not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_image", "image-to-text"):
            return
        self._log_image_text_caption(trainer, output, batch, "validation/predictions")

    @rank_zero_only
    def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx=0):
        if not self.has_metrics(output):
            return
        if check_and_warn(trainer.logger, "log_image", "image-to-text"):
            return
        self._log_image_text_caption(trainer, output, batch, "test/predictions")

    def _log_image_text_caption(self, trainer, output, batch, key):
        """Log up to max_images_to_log input images under `key`, captioned with predictions."""
        xs, _ = batch
        mx = self.max_images_to_log
        images = list(xs[:mx])
        pred_strs = output["pred_strs"][:mx]
        trainer.logger.log_image(key, images, caption=pred_strs)

    def has_metrics(self, output):
        """Return True if `output` is a mapping containing every key needed for logging."""
        # Lightning step outputs are not always dicts (e.g. a bare loss tensor or None)
        if not hasattr(output, "keys"):
            return False
        return all(key in output for key in self._required_keys)
================================================
FILE: lab08/text_recognizer/callbacks/model.py
================================================
import os
from pathlib import Path
import tempfile
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import torch
from .util import check_and_warn, logging
try:
import torchviz
has_torchviz = True
except ImportError:
has_torchviz = False
class ModelSizeLogger(pl.Callback):
    """Logs information about model size (in parameters and on disk).

    Logged metrics are "size/mb_disk" (state-dict size in MB) and "size/nparams".

    Args:
        print_size: if True, also print the state-dict disk size to stdout.
    """

    def __init__(self, print_size=True):
        super().__init__()
        self.print_size = print_size

    @rank_zero_only
    def on_fit_start(self, trainer, module):
        self._run(trainer, module)

    def _run(self, trainer, module):
        """Measure the module, optionally print its size, and log the metrics."""
        metrics = {}
        metrics["mb_disk"] = self.get_model_disksize(module)
        metrics["nparams"] = count_params(module)

        if self.print_size:
            print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB")

        # e.g. Trainer(logger=False) leaves trainer.logger as None
        if check_and_warn(trainer.logger, "log_metrics", "model size"):
            return

        metrics = {f"size/{key}": value for key, value in metrics.items()}
        trainer.logger.log_metrics(metrics, step=-1)

    @staticmethod
    def get_model_disksize(module):
        """Determine the model's size on disk, in MB (1e6 bytes), by saving its state dict."""
        with tempfile.NamedTemporaryFile() as f:
            torch.save(module.state_dict(), f)
            # flush Python's write buffer first, or getsize can miss the buffered tail of the file
            f.flush()
            size_mb = os.path.getsize(f.name) / 1e6
        return size_mb
class GraphLogger(pl.Callback):
    """Logs a compute graph as an image.

    On the first training batch, renders the autograd graph of one output tensor with
    torchviz and logs it under the key "graph". Only one attempt is ever made.

    Args:
        output_key: key of the tensor to graph within the step outputs' "extra" entry.

    Raises:
        ImportError: if torchviz is not installed.
    """

    def __init__(self, output_key="logits"):
        super().__init__()
        self.graph_logged = False
        self.output_key = output_key
        if not has_torchviz:
            raise ImportError("GraphLogger requires torchviz.")

    @rank_zero_only
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0: recent Lightning versions call this hook without it
        if self.graph_logged:
            return
        try:
            # NOTE(review): this nesting matches an older Lightning output format; other
            # formats fail the lookup and are reported below instead of crashing.
            graph_outputs = outputs[0][0]["extra"][self.output_key]
        except (KeyError, IndexError, TypeError):
            logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}")
        else:
            self.log_graph(trainer, module, graph_outputs)
        self.graph_logged = True  # attempt only once, whether or not it succeeded

    @staticmethod
    def log_graph(trainer, module, outputs):
        """Render the graph of `outputs` to a PNG in the experiment dir and log it."""
        if check_and_warn(trainer.logger, "log_image", "graph"):
            return
        params_dict = dict(list(module.named_parameters()))
        graph = torchviz.make_dot(outputs, params=params_dict)
        graph.format = "png"
        fname = Path(trainer.logger.experiment.dir) / "graph"
        graph.render(fname)
        fname = str(fname.with_suffix("." + graph.format))
        trainer.logger.log_image(key="graph", images=[fname])
def count_params(module):
    """Return the total number of elements across all of `module`'s parameters."""
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total
================================================
FILE: lab08/text_recognizer/callbacks/optim.py
================================================
import pytorch_lightning as pl
KEY = "optimizer"
class LearningRateMonitor(pl.callbacks.LearningRateMonitor):
    """Lightning's LearningRateMonitor, with logged names nested under the KEY namespace.

    Logs the learning rate during training; see pl.callbacks.LearningRateMonitor for details.
    """

    def _add_prefix(self, *args, **kwargs) -> str:
        base_name = super()._add_prefix(*args, **kwargs)
        return "/".join([KEY, base_name])
================================================
FILE: lab08/text_recognizer/callbacks/util.py
================================================
import logging
logging.basicConfig(level=logging.WARNING)
def check_and_warn(logger, attribute, feature):
    """Check whether `logger` supports `attribute`, warning if it does not.

    Args:
        logger: the logger object to inspect (may be None).
        attribute: name of the attribute/method the feature needs, e.g. "log_table".
        feature: human-readable feature name, used in the warning message.

    Returns:
        True if the attribute is missing (the caller should skip the feature), else False.
    """
    if not hasattr(logger, attribute):
        warn_no_attribute(feature, attribute)
        return True
    return False


def warn_no_attribute(blocked_feature, missing_attribute):
    """Log a warning that `blocked_feature` is skipped because the logger lacks `missing_attribute`."""
    logging.warning("Unable to log %s: logger does not have attribute %s.", blocked_feature, missing_attribute)
================================================
FILE: lab08/text_recognizer/data/__init__.py
================================================
"""Module containing submodules for each dataset.
Each dataset is defined as a class in that submodule.
The datasets should have a .config method that returns
any configuration information needed by the model.
Most datasets define their constants in a submodule
of the metadata module that is parallel to this one in the
hierarchy.
"""
from .util import BaseDataset
from .base_data_module import BaseDataModule
from .mnist import MNIST
from .emnist import EMNIST
from .emnist_lines import EMNISTLines
from .iam_paragraphs import IAMParagraphs
from .iam_lines import IAMLines
from .fake_images import FakeImageData
from .iam_synthetic_paragraphs import IAMSyntheticParagraphs
from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs
================================================
FILE: lab08/text_recognizer/data/base_data_module.py
================================================
"""Base DataModule class."""
import argparse
import os
from pathlib import Path
from typing import Collection, Dict, Optional, Tuple, Union
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
from text_recognizer import util
from text_recognizer.data.util import BaseDataset
import text_recognizer.metadata.shared as metadata
def load_and_print_info(data_module_class) -> None:
    """Build a data module from command-line arguments, prepare and set it up, then print it.

    Args:
        data_module_class: a data module class providing a static `add_to_argparse(parser)`,
            a constructor taking the parsed `argparse.Namespace`, `prepare_data()`, and `setup()`.
    """
    parser = argparse.ArgumentParser()
    data_module_class.add_to_argparse(parser)
    args = parser.parse_args()
    dataset = data_module_class(args)
    dataset.prepare_data()
    dataset.setup()
    print(dataset)
def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
    """Download the file described by `metadata` into `dl_dirname`, unless already present.

    Args:
        metadata: dict with "filename", "url", and "sha256" entries.
        dl_dirname: directory to download into; created if it does not exist.

    Returns:
        Path of the (downloaded or pre-existing) file.

    Raises:
        ValueError: if a fresh download's SHA-256 does not match metadata["sha256"].
    """
    dl_dirname.mkdir(parents=True, exist_ok=True)
    filename = dl_dirname / metadata["filename"]
    if filename.exists():
        # NOTE: pre-existing files are trusted and not re-checksummed
        return filename
    print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
    util.download_url(metadata["url"], filename)
    print("Computing SHA-256...")
    sha256 = util.compute_sha256(filename)
    if sha256 != metadata["sha256"]:
        # delete the bad file; otherwise the exists() shortcut above would silently reuse it next time
        if filename.exists():
            filename.unlink()
        raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
    return filename
BATCH_SIZE = 128

# os.sched_getaffinity respects CPU pinning but only exists on some platforms (e.g. Linux);
# fall back to the total CPU count elsewhere (e.g. macOS, Windows).
try:
    NUM_AVAIL_CPUS = len(os.sched_getaffinity(0))
except AttributeError:
    NUM_AVAIL_CPUS = os.cpu_count() or 1
NUM_AVAIL_GPUS = torch.cuda.device_count()

# sensible multiprocessing defaults: at most one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS
# but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU
DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS
class BaseDataModule(pl.LightningDataModule):
    """Base for all of our LightningDataModules.

    Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html
    """

    def __init__(self, args: Optional[argparse.Namespace] = None) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}
        self.batch_size = self.args.get("batch_size", BATCH_SIZE)
        self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS)

        gpus = self.args.get("gpus", None)
        # presumably a GPU count (int) or device-id string; 0 means CPU-only, so don't pin memory
        self.on_gpu = isinstance(gpus, (str, int)) and gpus != 0

        # Make sure to set the variables below in subclasses
        self.input_dims: Tuple[int, ...]
        self.output_dims: Tuple[int, ...]
        self.mapping: Collection
        self.data_train: Union[BaseDataset, ConcatDataset]
        self.data_val: Union[BaseDataset, ConcatDataset]
        self.data_test: Union[BaseDataset, ConcatDataset]

    @classmethod
    def data_dirname(cls):
        return metadata.DATA_DIRNAME

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument(
            "--batch_size",
            type=int,
            default=BATCH_SIZE,
            help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=DEFAULT_NUM_WORKERS,
            help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.",
        )
        return parser

    def config(self):
        """Return important settings of the dataset, which will be passed to instantiate models."""
        return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping}

    def prepare_data(self, *args, **kwargs) -> None:
        """Take the first steps to prepare data for use.

        Use this method to do things that might write to disk or that need to be done only from a single GPU
        in distributed settings (so don't set state `self.x = y`).
        """

    def setup(self, stage: Optional[str] = None) -> None:
        """Perform final setup to prepare data for consumption by DataLoader.

        Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap `dataset` in a DataLoader using this module's batch size, workers, and memory pinning."""
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.data_test, shuffle=False)
================================================
FILE: lab08/text_recognizer/data/emnist.py
================================================
"""EMNIST dataset. Downloads from NIST website and saves as a compressed HDF5 file if not already present."""
import json
import os
from pathlib import Path
import shutil
from typing import Sequence
import zipfile
import h5py
import numpy as np
import toml
from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info
from text_recognizer.data.util import BaseDataset, split_dataset
import text_recognizer.metadata.emnist as metadata
from text_recognizer.stems.image import ImageStem
from text_recognizer.util import temporary_working_directory
# Module-level aliases for settings defined in text_recognizer.metadata.emnist.
NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS  # offset added to every raw EMNIST label
RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME
METADATA_FILENAME = metadata.METADATA_FILENAME  # TOML file with the raw archive's filename, url, and sha256
DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME  # where the raw archive is downloaded and unzipped
PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME
PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME  # HDF5 file written by _process_raw_dataset
ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME  # JSON file with characters and input_shape
SAMPLE_TO_BALANCE = True  # If true, take at most the mean number of instances per class.
# Fraction of the original training split used for training; the remainder becomes validation (see EMNIST.setup).
TRAIN_FRAC = 0.8
class EMNIST(BaseDataModule):
    """EMNIST dataset of handwritten characters and digits.

    "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19
    and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset."
    From https://www.nist.gov/itl/iad/image-group/emnist-dataset

    The data split we will use is
    EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
    """

    def __init__(self, args=None):
        super().__init__(args)
        self.mapping = metadata.MAPPING
        # Reverse lookup: each entry of `mapping` -> its integer class index.
        self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
        self.transform = ImageStem()
        self.input_dims = metadata.DIMS
        self.output_dims = metadata.OUTPUT_DIMS

    def prepare_data(self, *args, **kwargs) -> None:
        """Download and process the raw data unless the processed HDF5 file already exists."""
        if not os.path.exists(PROCESSED_DATA_FILENAME):
            _download_and_process_emnist()

    def setup(self, stage: str = None) -> None:
        """Load processed arrays from the HDF5 file into Datasets.

        Args:
            stage: "fit" loads the training split and divides it into train/val
                (fraction TRAIN_FRAC, seed 42); "test" loads the test split; None loads both.
        """
        if stage == "fit" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_trainval = f["x_train"][:]
                self.y_trainval = f["y_train"][:].squeeze().astype(int)
            data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform)
            self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42)
        if stage == "test" or stage is None:
            with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
                self.x_test = f["x_test"][:]
                self.y_test = f["y_test"][:].squeeze().astype(int)
            self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform)

    def __repr__(self):
        basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n"
        # BaseDataModule only annotates these attributes, so they do not exist until `setup` assigns them,
        # and setup(stage="fit") / setup(stage="test") each assign only some of them.
        datasets = [getattr(self, name, None) for name in ("data_train", "data_val", "data_test")]
        if all(dataset is None for dataset in datasets):
            return basic
        sizes = ", ".join("None" if dataset is None else str(len(dataset)) for dataset in datasets)
        data = f"Train/val/test sizes: {sizes}\n"
        if datasets[0] is not None:  # batch statistics are drawn from the training DataLoader
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
                f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
            )
        return basic + data
def _download_and_process_emnist():
    """Fetch the raw EMNIST archive described by the metadata TOML file, then convert it to HDF5."""
    # Named `raw_metadata` so it does not shadow the module-level `metadata` import.
    raw_metadata = toml.load(METADATA_FILENAME)
    _download_raw_dataset(raw_metadata, DL_DATA_DIRNAME)
    archive_name = raw_metadata["filename"]
    _process_raw_dataset(archive_name, DL_DATA_DIRNAME)
def _process_raw_dataset(filename: str, dirname: Path):
    """Convert the downloaded EMNIST zip archive into the processed HDF5 file and the essentials JSON file.

    Args:
        filename: Name of the zip archive, opened relative to the working directory set up below.
        dirname: Directory holding the archive; processing runs with it as the working directory.
    """
    print("Unzipping EMNIST...")
    # Presumably `temporary_working_directory` chdirs into `dirname` and restores the previous cwd on exit
    # (confirm in text_recognizer.util); the relative paths `filename` and "matlab/..." rely on that.
    with temporary_working_directory(dirname):
        with zipfile.ZipFile(filename, "r") as zf:
            # Only the ByClass split is extracted from the archive.
            zf.extract("matlab/emnist-byclass.mat")
        # NOTE: If importing at the top of module, would need to list scipy as prod dependency.
        from scipy.io import loadmat

        print("Loading training data from .mat file")
        data = loadmat("matlab/emnist-byclass.mat")
        # Reshape each image to 28x28, then swap the two pixel axes
        # (presumably the .mat stores each image transposed -- TODO confirm).
        # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices
        x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
        y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
        x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2)
        y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
        if SAMPLE_TO_BALANCE:
            print("Balancing classes to reduce amount of data")
            x_train, y_train = _sample_to_balance(x_train, y_train)
            x_test, y_test = _sample_to_balance(x_test, y_test)

        print("Saving to HDF5 in a compressed format...")
        PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
        # All arrays are stored as unsigned bytes ("u1") with LZF compression; this assumes labels
        # (after the NUM_SPECIAL_TOKENS offset) stay below 256 -- TODO confirm if classes are added.
        with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
            f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
            f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf")
            f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
            f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")

        print("Saving essential dataset parameters to text_recognizer/data...")
        # The .mat mapping holds (label, character code) pairs; build {label: character}.
        mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
        characters = _augment_emnist_characters(list(mapping.values()))
        essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
        with open(ESSENTIALS_FILENAME, "w") as f:
            json.dump(essentials, f)

        print("Cleaning up...")
        # Removes the extracted .mat directory; this step is skipped if anything above raises.
        shutil.rmtree("matlab")
def _sample_to_balance(x, y):
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(42)
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_inds = []
for label in np.unique(y.flatten()):
inds = np.where(y == label)[0]
sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
all_sampled_inds.append(sampled_inds)
ind = np.concatenate(all_sampled_inds)
x_sampled = x[ind]
y_sampled = y[ind]
return x_sampled, y_sampled
def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]:
    """Augment the mapping with extra symbols.

    Per the comments below, the result should start with four special tokens (indices 0-3) and also
    include the extra IAM punctuation characters; see the NOTE(review) on the return statement.

    Args:
        characters: The EMNIST characters, in class-index order.
    """
    # Extra characters from the IAM dataset
    iam_characters = [
        " ",
        "!",
        '"',
        "#",
        "&",
        "'",
        "(",
        ")",
        "*",
        "+",
        ",",
        "-",
        ".",
        "/",
        ":",
        ";",
        "?",
    ]
    # Also add special tokens:
    # - CTC blank token at index 0
    # - Start token at index 1
    # - End token at index 2
    # - Padding token at index 3
    # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this!
    # NOTE(review): the return statement below looks corrupted in this copy of the file: the special-token
    # strings appear empty/stripped, a string literal spans a raw newline, `characters` and `iam_characters`
    # are unused, and `tokens` is not defined in this function. Restore it from version control.
    return ["", "
", "", " and ", *tokens, " and ", *tokens, ""]
self.end_index = self.inverse_mapping["",
""]
self.end_token = inverse_mapping[""]
self.end_token = inverse_mapping[""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MX9n-Zed8G_T"
},
"source": [
"# Lab 00: The 🥞 Full Stack 🥞 of the Text Recognizer Application"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OggjLhU3f9gk"
},
"source": [
"In the course of these labs,\n",
"you will build an optical character recognition (OCR) application\n",
"that is powered by a neural network:\n",
"the \"FSDL Text Recognizer\".\n",
"\n",
"We use this application to\n",
"- demonstrate general principles for engineering an ML-powered application,\n",
"- provide a \"worked example\" that includes all of the juicy details, and\n",
"- introduce you to tools, libraries, and practices that we consider best-in-class or best for independent ML engineers working across the full stack.\n",
"\n",
"You can try it out inside this notebook below,\n",
"or you can simply navigate to the `app_url` in your browser."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "g9xKjSYie6ck"
},
"outputs": [],
"source": [
"from IPython.display import IFrame\n",
"\n",
"app_url = \"https://fsdl-text-recognizer.ngrok.io/\"\n",
"\n",
"IFrame(app_url, width=1024, height=896)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaDkEosIjcl6"
},
"source": [
"## Frontend and Backend"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cDxvNgFHgM_J"
},
"source": [
"What you see above is the \"frontend\",\n",
"the user-facing component, of the application.\n",
"\n",
"Frontend web development is typically done using\n",
"JavaScript as the programming language.\n",
"Most ML is done in Python (see below),\n",
"so we will instead build our frontend using\n",
"the Python library [**Gradio**](https://gradio.app/).\n",
"\n",
"> Another excellent choice for pure Python web development might be\n",
"[Streamlit](https://streamlit.io/)\n",
"or even, in the near future, tools built around\n",
"[PyScript](https://pyscript.net/).\n",
"\n",
"Notice the option to \"flag\" the model's outputs.\n",
"This user feedback will be sent to [**Gantry**](https://gantry.io/),\n",
"where we can monitor model performance,\n",
"generate alerts,\n",
"and do exploratory data analysis.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ywyH6kW5uUjH"
},
"source": [
"\n",
"The model that reads the image to produce the text\n",
"is not running\n",
"in the same place as the frontend.\n",
"The model is the \"backend\" of our application.\n",
"We separate the two via a JSON API.\n",
"The model is deployed\n",
"[serverlessly](https://serverless-stack.com/chapters/what-is-serverless.html)\n",
"to Amazon Web Services using\n",
"[**AWS Lambda**](https://aws.amazon.com/lambda/),\n",
"which runs a\n",
"[**Docker**](https://docker-curriculum.com/)\n",
"container that wraps up our model.\n",
"\n",
"> Docker is the tool of choice for virtualization/containerization. As containerized applications become more complex,\n",
"[container orchestration](https://www.vmware.com/topics/glossary/content/container-management.html)\n",
"becomes important. The premier tool for orchestrating\n",
"Docker containers is\n",
"[kubernetes](https://kubernetes.io/), aka k8s.\n",
"Non-experts on cloud infrastructure will want to use their providers' managed service for k8s, e.g.\n",
"[AWS EKS](https://aws.amazon.com/eks/)\n",
"or [Google Kubernetes Engine](https://cloud.google.com/kubernetes-engine).\n",
"\n",
"The container image lives inside the\n",
"[Elastic Container Registry](https://aws.amazon.com/ecr/),\n",
"a sort of \"GitHub for Docker\" on AWS.\n",
"The choice to go serverless makes it effortless to scale up our model\n",
"across a number of orders of magnitude\n",
"and the choice to containerize reduces friction and error\n",
"when moving our model from development to production.\n",
"\n",
"> This could equally as well be done on another cloud,\n",
"like [Google Cloud Platform](https://cloud.google.com/)\n",
"or [Microsoft Azure](https://azure.microsoft.com/en-us/),\n",
"which offer serverless deployment via\n",
"[Google Cloud Functions](https://cloud.google.com/serverless)\n",
"and [Azure Functions](https://azure.microsoft.com/en-us/solutions/serverless),\n",
"respectively. "
]
},
{
"cell_type": "markdown",
"source": [
"The backend operates completely independently of the frontend,\n",
"which means it can be used in multiple contexts.\n",
"\n",
"Run the cell below to send a query directly to the model on the backend."
],
"metadata": {
"id": "3XBHox87IJ8i"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "76HwEP2Vzz3F"
},
"outputs": [],
"source": [
"import json # JavaScript Object Notation is the lingua franca of the web\n",
"\n",
"from IPython.display import Image\n",
"import requests # requests is the preferred library for web requests in Python\n",
"\n",
"lambda_url = \"https://3akxma777p53w57mmdika3sflu0fvazm.lambda-url.us-west-1.on.aws/\"\n",
"image_url = \"https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png\"\n",
"\n",
"headers = {\"Content-type\": \"application/json\"} # headers ensure our request is handled correctly\n",
"payload = json.dumps({\"image_url\": image_url}) # the request content is a string representation of JSON data\n",
"\n",
"if \"pred\" not in locals(): # a poor man's cache: if we've defined the variable pred, skip the request\n",
" response = requests.post( # we POST the image to the URL, expecting a prediction as a response\n",
" lambda_url, data=payload, headers=headers)\n",
" pred = response.json()[\"pred\"] # the response is also json\n",
"\n",
"print(pred)\n",
"\n",
"Image(url=image_url, width=512)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "csthw0QlgeSy"
},
"source": [
"## Application Diagram"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "baYhDRKkggNk"
},
"source": [
"We're only halfway through describing how the Text Recognizer works\n",
"and it's already getting hard to hold the whole thing in-memory.\n",
"\n",
"Run the cell below to show a diagram that incorporates the entire\n",
"process for creating and running the Text Recognizer,\n",
"from training to feedback collection."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bsOa6gQ0YhX4"
},
"outputs": [],
"source": [
"diagram_url = \"https://miro.com/app/live-embed/uXjVOrOHcOg=/?moveToViewport=-1210,-1439,2575,1999\"\n",
"\n",
"IFrame(diagram_url, width=1024, height=512)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RiQgHY6Th67H"
},
"source": [
"## Model Training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ib6ijsumjjlm"
},
"source": [
"Let's start back at the beginning -- developing a model.\n",
"We'll then make our way back to where we left off above, the handoff\n",
"from model development/training\n",
"to the actual application.\n",
"\n",
"We begin by training a neural network\n",
"(a [ResNet](https://pytorch.org/hub/pytorch_vision_resnet/)\n",
"encoder to process the images and \n",
"a [Transformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)\n",
"decoder to produce the output text).\n",
"\n",
"Neural networks operate by applying\n",
"sequences of large matrix multiplications\n",
"and other array operations.\n",
"These operations are much faster on GPUs than on CPUs\n",
"and are relatively easy to parallelize\n",
"across GPUs.\n",
"This is especially true during training,\n",
"where many inputs are processed in parallel,\n",
"or \"batched\" together.\n",
"\n",
"Purchasing GPUs and properly setting up\n",
"a multi-GPU machine is challenging\n",
"and has high up-front costs.\n",
"So we run our training via a cloud provider,\n",
"specifically\n",
"[**Lambda Labs GPU Cloud**](https://lambdalabs.com/service/gpu-cloud).\n",
"\n",
"> Other cloud providers offer GPU-accelerated compute\n",
"but Lambda Labs offers it at the lowest prices,\n",
"as of August 2022.\n",
"Larger organizations may benefit from the extra features\n",
"that integration with larger cloud providers,\n",
"like AWS or GCP, can provide (e.g. unified authorization\n",
"and control planes).\n",
"Because independent, full-stack developers\n",
"are often very price-sensitive, we recommend Lambda Labs --\n",
"even more, we recommend checking current and historical instance prices.\n",
"\n",
"\n",
"For smaller units of work, like debugging and quick experiments,\n",
"we can use\n",
"[Google Colaboratory](https://research.google.com/colaboratory/),\n",
"which provides limited access to free GPU (and TPU)\n",
"compute in an ephemeral environment.\n",
"\n",
"> For small-to-medium-sized deep learning tasks,\n",
"Colab Pro (\\$10/mo.) and Colab Pro+ (\\$50/mo.)\n",
"can be competitive with the larger cloud providers.\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LUS001EIv3H7"
},
"source": [
"If you're running this notebook on a machine with a GPU,\n",
"e.g. on Colab, running the cell below\n",
"will show some basic information on the GPU's state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WyYVgQmlv091"
},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C-C6iWKAsZI3"
},
"source": [
"Because the heavy work is done on the GPU,\n",
"using lower-level libraries,\n",
"we don't need to write the majority of our model development code\n",
"in a performant language like C/C++ or Rust.\n",
"\n",
"We can instead write in a more comfortable, but slower language:\n",
"it doesn't make sense to drive an F1 car to the airport\n",
"for an international flight.\n",
"\n",
"The language of choice for deep learning is\n",
"[**Python**](https://www.python.org/).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FQALjrGVwFeG"
},
"outputs": [],
"source": [
"import this # The Zen of Python"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-uqfsWUQwEyl"
},
"source": [
"We don't want to write our Python library for GPU acceleration from scratch,\n",
"especially because we also need automatic differentiation --\n",
"the ability to take derivatives of our neural networks.\n",
"The Python/C++ library\n",
"[PyTorch](https://pytorch.org/)\n",
"offers GPU-accelerated array math with automatic differentiation,\n",
"plus special neural network primitives and architectures.\n",
"\n",
"> There are two major alternatives to PyTorch\n",
"for providing accelerated, differentiable array math,\n",
"both from Google: early mover\n",
"[TensorFlow](https://www.tensorflow.org/resources/learn-ml)\n",
"and new(ish)comer\n",
"[JAX](https://github.com/google/jax).\n",
"The former is more common in certain larger, older enterprise settings\n",
"and the latter is more common in certain bleeding-edge research settings.\n",
"We choose PyTorch to split the difference,\n",
"but can confidently recommend all three.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qcvCJ6b1wVRl"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # run on GPU if available\n",
"\n",
"# create an array/tensor and track its gradients during calculations\n",
"a = torch.tensor([1.], requires_grad=True) \\\n",
" .to(device) # store the array data on GPU (if available)\n",
"b = torch.tensor([2.]).to(device)\n",
"\n",
"# calculate new values, building up a \"compute graph\"\n",
"c = a * b + a\n",
"\n",
"# compute gradient of c with respect to a by \"tracing the graph backwards\"\n",
"g, = torch.autograd.grad(outputs=c, inputs=a)\n",
"\n",
"g"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4zjQyN4HwUS0"
},
"source": [
"\n",
"PyTorch provides a number of features required for creating\n",
"deep neural networks,\n",
"but it doesn't include a high-level framework\n",
"for training or any of a number of related engineering tasks,\n",
"like metric calculation or model checkpointing.\n",
"\n",
"We use the\n",
"[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)\n",
"library as our high-level training engineering framework.\n",
"\n",
"> PyTorch Lightning is the framework of choice\n",
"for generic deep learning in PyTorch,\n",
"but in natural language processing,\n",
"many people instead choose libraries from\n",
"[Hugging Face](https://huggingface.co/).\n",
"[Keras](https://keras.io/)\n",
"is the framework of choice for TensorFlow.\n",
"In some ways,\n",
"[Flax](https://github.com/google/flax)\n",
"is the same for JAX;\n",
"in others, as of July 2022, there is no high-level\n",
"training framework in JAX."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ehvGUApGpnrV"
},
"outputs": [],
"source": [
"from IPython.display import YouTubeVideo\n",
"\n",
"lit_video_id = \"QHww1JH7IDU\"\n",
"YouTubeVideo(lit_video_id, modestbranding=True, rel=False, width=512)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7gu9b4Ux1U-k"
},
"source": [
"## Experiment and Artifact Tracking"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vWhTbHq3sfON"
},
"source": [
"ML models are challenging to debug:\n",
"their inputs and outputs are often easy for humans to interpret\n",
"but hard for traditional software programs to understand.\n",
"\n",
"They are also challenging to design:\n",
"there are a number of knobs to twiddle and constants to set,\n",
"like a finicky bunch of compiler flags.\n",
"These are known as \"hyperparameters\".\n",
"\n",
"So building an ML model often looks a bit less like engineering\n",
"and a bit more like experimentation.\n",
"These experiments need to be tracked,\n",
"as do large binary files,\n",
"or artifacts,\n",
"that are produced during those experiments\n",
"-- like model weights.\n",
"\n",
"We choose\n",
"[Weights & Biases](http://docs.wandb.ai)\n",
"as our experiment and artifact tracking platform.\n",
"\n",
"> [MLFlow](https://github.com/mlflow/mlflow)\n",
"is an open-source library that provides similar\n",
"features to W&B, but the experiment and artifact\n",
"tracking server must be self-hosted,\n",
"which can be burdensome for the already beleaguered\n",
"full-stack ML developer.\n",
"Basic experiment tracking can also be done\n",
"using [Tensorboard](https://www.tensorflow.org/tensorboard),\n",
"and shared using [tensorboard.dev](https://tensorboard.dev/),\n",
"but Tensorboard does not provide artifact tracking.\n",
"Artifact tracking and versioning can be done using\n",
"[Git LFS](https://git-lfs.github.com/),\n",
"but storage and distribution via GitHub can be expensive\n",
"and it does not provide experiment tracking.\n",
"[Hugging Face](https://huggingface.co/) runs an alternative\n",
"git server, Hugging Face Spaces,\n",
"that can display Tensorboard experiments and\n",
"mandates Git LFS for large files (where large means >10MB).\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rL1uL-SewukM"
},
"source": [
"The resulting experiment logs can be made very rich\n",
"and are invaluable for debugging\n",
"(e.g. tracking bugs through the git history)\n",
"and communicating results inside and across teams.\n",
"\n",
"Run the cell below to display the logs from an experiment\n",
"that was run while designing and debugging the Text Recognizer model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Uw4LUYkgwvb0"
},
"outputs": [],
"source": [
"experiment_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/runs/lfjjnxw8\"\n",
"\n",
"IFrame(experiment_url, width=1024, height=768)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O5WHO_CTwrgf"
},
"source": [
"Logged _data_ is inert.\n",
"It becomes usable, actionable _information_\n",
"when it is given context and form.\n",
"\n",
"Run the cell below to take a look at a dashboard,\n",
"built inside W&B,\n",
"reporting the results of a training run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W0V-qbh8uWwb"
},
"outputs": [],
"source": [
"dashboard_url= \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n",
"\n",
"IFrame(dashboard_url, width=1024, height=768)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R3I60PY61IXH"
},
"source": [
"## The Handoff to Production"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IsT1P1UG1hXW"
},
"source": [
"PyTorch Lightning produces large artifacts called \"checkpoints\"\n",
"that can be used to restart model training when it stops or is interrupted\n",
"(which allows the use of much cheaper\n",
"[\"preemptible\" cloud instances](https://www.determined.ai/blog/scale-your-model-development-on-a-budget)).\n",
"\n",
"We store these artifacts on Weights & Biases.\n",
"\n",
"When they are ready to be deployed to production,\n",
"we compile these model checkpoints down to a dialect of Torch called\n",
"[torchscript](https://pytorch.org/docs/stable/jit.html)\n",
"that is more portable:\n",
"it drops the training engineering code\n",
"and produces an artifact that is executable in C++ or in Python.\n",
"We stick with a Python environment for simplicity.\n",
"\n",
"> TensorFlow has similar facilities\n",
"for delivering models, including\n",
"[tensorflow.js](https://www.tensorflow.org/js)\n",
"and [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx).\n",
"There are also a number of alternative portable runtime environments\n",
"for ML models, including\n",
"[ONNX RT](https://onnx.ai/).\n",
"\n",
"These deployable models are also stored on Weights & Biases,\n",
"which connects them to rich metadata,\n",
"including the experiments and training runs\n",
"that produced the checkpoints from which they were derived.\n",
"\n",
"Run the cell below to review the metadata for a deployable\n",
"version of the Text Recognizer model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AB8OYJ423Qvy"
},
"outputs": [],
"source": [
"artifact_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/artifacts/prod-ready/paragraph-text-recognizer/v8\"\n",
"\n",
"IFrame(artifact_url, width=1024, height=768)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4O6VGIqM3ugW"
},
"source": [
"We can pull this file down,\n",
"package it into a Docker container\n",
"via a small Python script,\n",
"and ship it off to a container registry, like AWS ECR or Docker Hub, so that\n",
"it can be used to provide the backend to our application."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CS-1UA1s1hsl"
},
"source": [
"## Application Diagram, Redux"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eqr39GhA313z"
},
"source": [
"Now that we have made it through the\n",
"🥞 full stack 🥞 of the Text Recognizer application,\n",
"let's take a look at the application diagram again."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SH1acJ8f1kpP"
},
"outputs": [],
"source": [
"IFrame(diagram_url, width=1024, height=512)"
]
},
{
"cell_type": "markdown",
"source": [
"Over the remainder of the labs,\n",
"we will put all of these pieces together,\n",
"learning more about the problems they solve,\n",
"the tradeoffs they make,\n",
"and how they are best used."
],
"metadata": {
"id": "j_Yy4d3Dpi3o"
}
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "Lab 00 - Overview.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3.7.13 ('fsdl-text-recognizer-2022')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"vscode": {
"interpreter": {
"hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: pyproject.toml
================================================
[tool.flake8] # configured in .flake8
[tool.darglint] # configured in .flake8
[tool.black]
line-length = 120
target-version = ["py37"]
[tool.mypy]
ignore_missing_imports = true
exclude = ["training/logs"]
[tool.pytest.ini_options]
markers = [
"slow: marks a test as slow (deselect with '-m \"not slow\"')",
"data: marks a test as dependent on a data download (deselect with '-m \"not data\"')"
]
addopts = "--cov training --cov text_recognizer --cov-branch --doctest-modules --ignore training/logs -m 'not data' --ignore-glob **/bootstrap.py"
================================================
FILE: readme.md
================================================
# 🥞 Full Stack Deep Learning Fall 2022 Labs
Welcome!
As part of Full Stack Deep Learning 2022, we will incrementally develop a complete deep learning codebase to create and deploy a model that understands the content of hand-written paragraphs.
For an overview of the Text Recognizer application architecture, click the badge below to open an interactive Jupyter notebook on Google Colab:
We will use the modern stack of [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://www.pytorchlightning.ai/).
We will use the main workhorses of DL today: CNNs and Transformers.
We will manage our experiments using what we believe to be the best tool for the job: [Weights & Biases](https://docs.wandb.ai/).
We will set up a quality assurance and continuous integration system for our codebase using [pre-commit](https://pre-commit.com/) and [GitHub Actions](https://docs.github.com/en/actions).
We will package up the prediction system and deploy it as a [Docker](https://docs.docker.com/) container on [AWS Lambda](https://aws.amazon.com/lambda/).
We will wrap that prediction system in a frontend written in Python using [Gradio](https://gradio.app/docs).
We will set up monitoring that alerts us to potential issues in our model using [Gantry](https://gantry.io/).
## Click the badges below to access individual lab notebooks on Colab and videos on YouTube
| Lab | Colab | Video |
| :-- | :---: | :---: |
| **Lab Overview** | [![open-in-colab]](https://fsdl.me/lab00-colab) | [![yt-logo]](https://fsdl.me/2022-lab-overview-video) |
| **Lab 01: Deep Neural Networks in PyTorch** | [![open-in-colab]](https://fsdl.me/lab01-colab) | [![yt-logo]](https://fsdl.me/2022-lab-01-video) |
| **Lab 02a: PyTorch Lightning** | [![open-in-colab]](https://fsdl.me/lab02a-colab) | [![yt-logo]](https://fsdl.me/2022-lab-02-video) |
| **Lab 02b: Training a CNN on Synthetic Handwriting Data** | [![open-in-colab]](https://fsdl.me/lab02b-colab) | [![yt-logo]](https://fsdl.me/2022-lab-02-video) |
| **Lab 03: Transformers and Paragraphs** | [![open-in-colab]](https://fsdl.me/lab03-colab) | [![yt-logo]](https://fsdl.me/2022-lab-03-video) |
| **Lab 04: Experiment Tracking** | [![open-in-colab]](https://fsdl.me/lab04-colab) | [![yt-logo]](https://fsdl.me/2022-lab-04-video) |
| **Lab 05: Troubleshooting & Testing** | [![open-in-colab]](https://fsdl.me/lab05-colab) | [![yt-logo]](https://fsdl.me/2022-lab-05-video) |
| **Lab 06: Data Annotation** | [![open-in-colab]](https://fsdl.me/lab06-colab) | [![yt-logo]](https://fsdl.me/2022-lab-06-video) |
| **Lab 07: Deployment** | [![open-in-colab]](https://fsdl.me/lab07-colab) | [![yt-logo]](https://fsdl.me/2022-lab-07-video) |
| **Lab 08: Monitoring** | [![open-in-colab]](https://fsdl.me/lab08-colab) | [![yt-logo]](https://fsdl.me/2022-lab-08-video) |
[yt-logo]: https://fsdl.me/yt-logo-badge
[open-in-colab]: https://colab.research.google.com/assets/colab-badge.svg
================================================
FILE: requirements/dev-lint.in
================================================
-c prod.txt
-c dev.txt
bandit==1.7.4
black==22.3.0
darglint==1.8.1
flake8<4
flake8-bandit==3.0.0
flake8-bugbear==22.4.25
flake8-docstrings==1.6.0
flake8-import-order==0.18.1
mypy==0.960
# mypy version also pinned in .pre-commit-config.yaml
safety==1.10.3
shellcheck-py==0.8.0.4
types-toml==0.10.7
================================================
FILE: requirements/dev.in
================================================
-c prod.txt
boltons
coverage[toml]
defusedxml
itermplot
ipywidgets
matplotlib
nltk
pre-commit
pytest
pytest-cov
scipy
toml
# versioned to give pip hints
coverage[toml]==6.4
pytest==7.1.1
pytest-cov==3.0.0
# versioned to match Google Colab
notebook>=6.5,<6.6
nbconvert>=6.5,<6.6
seaborn>=0.13,<0.14
tornado>=6.3,<6.4
# versioned to improve stability
pytorch-lightning==1.6.3
torchmetrics<0.8
wandb==0.12.17
================================================
FILE: requirements/dev.txt
================================================
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile requirements/dev.in
#
absl-py==1.4.0
# via tensorboard
aiohttp==3.8.5
# via
# -c requirements/prod.txt
# fsspec
aiosignal==1.3.1
# via
# -c requirements/prod.txt
# aiohttp
anyio==3.7.1
# via
# -c requirements/prod.txt
# jupyter-server
argon2-cffi==23.1.0
# via
# jupyter-server
# nbclassic
# notebook
argon2-cffi-bindings==21.2.0
# via argon2-cffi
arrow==1.2.3
# via isoduration
asttokens==2.2.1
# via stack-data
async-timeout==4.0.3
# via
# -c requirements/prod.txt
# aiohttp
attrs==23.1.0
# via
# -c requirements/prod.txt
# aiohttp
# jsonschema
# pytest
# referencing
backcall==0.2.0
# via ipython
beautifulsoup4==4.12.2
# via nbconvert
bleach==6.0.0
# via nbconvert
boltons==23.0.0
# via -r requirements/dev.in
cachetools==4.2.4
# via
# -c requirements/prod.txt
# google-auth
certifi==2023.7.22
# via
# -c requirements/prod.txt
# requests
# sentry-sdk
cffi==1.15.1
# via argon2-cffi-bindings
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.2.0
# via
# -c requirements/prod.txt
# aiohttp
# requests
click==8.1.7
# via
# -c requirements/prod.txt
# nltk
# wandb
comm==0.1.4
# via
# ipykernel
# ipywidgets
contourpy==1.1.0
# via
# -c requirements/prod.txt
# matplotlib
coverage[toml]==6.4
# via
# -r requirements/dev.in
# pytest-cov
cycler==0.11.0
# via
# -c requirements/prod.txt
# matplotlib
debugpy==1.6.7.post1
# via ipykernel
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via
# -r requirements/dev.in
# nbconvert
distlib==0.3.7
# via virtualenv
docker-pycreds==0.4.0
# via wandb
entrypoints==0.4
# via
# jupyter-client
# nbconvert
exceptiongroup==1.1.3
# via
# -c requirements/prod.txt
# anyio
executing==1.2.0
# via stack-data
fastjsonschema==2.18.0
# via nbformat
filelock==3.12.2
# via
# -c requirements/prod.txt
# torch
# triton
# virtualenv
fonttools==4.42.1
# via
# -c requirements/prod.txt
# matplotlib
fqdn==1.5.1
# via jsonschema
frozenlist==1.4.0
# via
# -c requirements/prod.txt
# aiohttp
# aiosignal
fsspec[http]==2023.6.0
# via
# -c requirements/prod.txt
# pytorch-lightning
# torch
gitdb==4.0.10
# via gitpython
gitpython==3.1.32
# via wandb
google-auth==2.22.0
# via
# google-auth-oauthlib
# tensorboard
google-auth-oauthlib==1.0.0
# via tensorboard
grpcio==1.57.0
# via tensorboard
identify==2.5.27
# via pre-commit
idna==3.4
# via
# -c requirements/prod.txt
# anyio
# jsonschema
# requests
# yarl
iniconfig==2.0.0
# via pytest
ipykernel==6.25.1
# via
# nbclassic
# notebook
ipython==8.14.0
# via
# ipykernel
# ipywidgets
ipython-genutils==0.2.0
# via
# nbclassic
# notebook
ipywidgets==8.1.0
# via -r requirements/dev.in
isoduration==20.11.0
# via jsonschema
itermplot==0.331
# via -r requirements/dev.in
jedi==0.19.0
# via ipython
jinja2==3.1.2
# via
# -c requirements/prod.txt
# jupyter-server
# nbclassic
# nbconvert
# notebook
# torch
joblib==1.3.2
# via nltk
jsonpointer==2.4
# via jsonschema
jsonschema[format-nongpl]==4.19.0
# via
# -c requirements/prod.txt
# jupyter-events
# nbformat
jsonschema-specifications==2023.7.1
# via
# -c requirements/prod.txt
# jsonschema
jupyter-client==7.4.9
# via
# ipykernel
# jupyter-server
# nbclassic
# nbclient
# notebook
jupyter-core==5.3.1
# via
# ipykernel
# jupyter-client
# jupyter-server
# nbclassic
# nbclient
# nbconvert
# nbformat
# notebook
jupyter-events==0.7.0
# via jupyter-server
jupyter-server==2.7.2
# via
# nbclassic
# notebook-shim
jupyter-server-terminals==0.4.4
# via jupyter-server
jupyterlab-pygments==0.2.2
# via nbconvert
jupyterlab-widgets==3.0.8
# via ipywidgets
kiwisolver==1.4.5
# via
# -c requirements/prod.txt
# matplotlib
lxml==4.9.3
# via nbconvert
markdown==3.4.4
# via tensorboard
markupsafe==2.1.3
# via
# -c requirements/prod.txt
# jinja2
# nbconvert
# werkzeug
matplotlib==3.7.2
# via
# -c requirements/prod.txt
# -r requirements/dev.in
# itermplot
# seaborn
matplotlib-inline==0.1.6
# via
# ipykernel
# ipython
mistune==0.8.4
# via nbconvert
mpmath==1.3.0
# via
# -c requirements/prod.txt
# sympy
multidict==6.0.4
# via
# -c requirements/prod.txt
# aiohttp
# yarl
nbclassic==1.0.0
# via notebook
nbclient==0.8.0
# via nbconvert
nbconvert==6.5.4
# via
# -r requirements/dev.in
# jupyter-server
# nbclassic
# notebook
nbformat==5.9.2
# via
# jupyter-server
# nbclassic
# nbclient
# nbconvert
# notebook
nest-asyncio==1.5.7
# via
# ipykernel
# jupyter-client
# nbclassic
# notebook
networkx==3.1
# via
# -c requirements/prod.txt
# torch
nltk==3.8.1
# via -r requirements/dev.in
nodeenv==1.8.0
# via pre-commit
notebook==6.5.5
# via -r requirements/dev.in
notebook-shim==0.2.3
# via nbclassic
numpy==1.25.2
# via
# -c requirements/prod.txt
# contourpy
# itermplot
# matplotlib
# pandas
# pytorch-lightning
# scipy
# seaborn
# tensorboard
# torchmetrics
nvidia-cublas-cu12==12.1.3.1
# via
# -c requirements/prod.txt
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via
# -c requirements/prod.txt
# torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via
# -c requirements/prod.txt
# torch
nvidia-cuda-runtime-cu12==12.1.105
# via
# -c requirements/prod.txt
# torch
nvidia-cudnn-cu12==8.9.2.26
# via
# -c requirements/prod.txt
# torch
nvidia-cufft-cu12==11.0.2.54
# via
# -c requirements/prod.txt
# torch
nvidia-curand-cu12==10.3.2.106
# via
# -c requirements/prod.txt
# torch
nvidia-cusolver-cu12==11.4.5.107
# via
# -c requirements/prod.txt
# torch
nvidia-cusparse-cu12==12.1.0.106
# via
# -c requirements/prod.txt
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.18.1
# via
# -c requirements/prod.txt
# torch
nvidia-nvjitlink-cu12==12.3.101
# via
# -c requirements/prod.txt
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via
# -c requirements/prod.txt
# torch
oauthlib==3.2.2
# via requests-oauthlib
overrides==7.4.0
# via jupyter-server
packaging==23.1
# via
# -c requirements/prod.txt
# ipykernel
# jupyter-server
# matplotlib
# nbconvert
# pytest
# pytorch-lightning
# torchmetrics
pandas==2.0.3
# via
# -c requirements/prod.txt
# seaborn
pandocfilters==1.5.0
# via nbconvert
parso==0.8.3
# via jedi
pathtools==0.1.2
# via wandb
pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
pillow==9.4.0
# via
# -c requirements/prod.txt
# matplotlib
platformdirs==3.10.0
# via
# jupyter-core
# virtualenv
pluggy==1.2.0
# via pytest
pre-commit==3.3.3
# via -r requirements/dev.in
prometheus-client==0.17.1
# via
# jupyter-server
# nbclassic
# notebook
promise==2.3
# via wandb
prompt-toolkit==3.0.39
# via ipython
protobuf==3.20.3
# via
# tensorboard
# wandb
psutil==5.9.5
# via
# ipykernel
# wandb
ptyprocess==0.7.0
# via
# pexpect
# terminado
pure-eval==0.2.2
# via stack-data
py==1.11.0
# via pytest
pyasn1==0.5.0
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.3.0
# via google-auth
pycparser==2.21
# via cffi
pydeprecate==0.3.2
# via
# pytorch-lightning
# torchmetrics
pygments==2.16.1
# via
# ipython
# nbconvert
pyparsing==3.0.9
# via
# -c requirements/prod.txt
# matplotlib
pytest==7.1.1
# via
# -r requirements/dev.in
# pytest-cov
pytest-cov==3.0.0
# via -r requirements/dev.in
python-dateutil==2.8.2
# via
# -c requirements/prod.txt
# jupyter-client
# matplotlib
# pandas
# wandb
python-json-logger==2.0.7
# via jupyter-events
pytorch-lightning==1.6.3
# via -r requirements/dev.in
pytz==2023.3
# via
# -c requirements/prod.txt
# pandas
pyyaml==6.0.1
# via
# -c requirements/prod.txt
# jupyter-events
# pre-commit
# pytorch-lightning
# wandb
pyzmq==24.0.1
# via
# ipykernel
# jupyter-client
# jupyter-server
# nbclassic
# notebook
referencing==0.30.2
# via
# -c requirements/prod.txt
# jsonschema
# jsonschema-specifications
# jupyter-events
regex==2023.8.8
# via
# -c requirements/prod.txt
# nltk
requests==2.31.0
# via
# -c requirements/prod.txt
# fsspec
# requests-oauthlib
# tensorboard
# wandb
requests-oauthlib==1.3.1
# via google-auth-oauthlib
rfc3339-validator==0.1.4
# via
# jsonschema
# jupyter-events
rfc3986-validator==0.1.1
# via
# jsonschema
# jupyter-events
rpds-py==0.9.2
# via
# -c requirements/prod.txt
# jsonschema
# referencing
rsa==4.9
# via google-auth
scipy==1.11.2
# via -r requirements/dev.in
seaborn==0.13.1
# via -r requirements/dev.in
send2trash==1.8.2
# via
# jupyter-server
# nbclassic
# notebook
sentry-sdk==1.29.2
# via wandb
setproctitle==1.3.2
# via wandb
shortuuid==1.0.11
# via wandb
six==1.16.0
# via
# -c requirements/prod.txt
# asttokens
# bleach
# docker-pycreds
# google-auth
# itermplot
# promise
# python-dateutil
# rfc3339-validator
# wandb
smmap==5.0.0
# via gitdb
sniffio==1.3.0
# via
# -c requirements/prod.txt
# anyio
soupsieve==2.4.1
# via beautifulsoup4
stack-data==0.6.2
# via ipython
sympy==1.12
# via
# -c requirements/prod.txt
# torch
tensorboard==2.14.0
# via pytorch-lightning
tensorboard-data-server==0.7.1
# via tensorboard
terminado==0.17.1
# via
# jupyter-server
# jupyter-server-terminals
# nbclassic
# notebook
tinycss2==1.2.1
# via nbconvert
toml==0.10.2
# via -r requirements/dev.in
tomli==2.0.1
# via
# coverage
# pytest
torch==2.1.1
# via
# -c requirements/prod.txt
# pytorch-lightning
# torchmetrics
torchmetrics==0.7.3
# via
# -r requirements/dev.in
# pytorch-lightning
tornado==6.3.3
# via
# -r requirements/dev.in
# ipykernel
# jupyter-client
# jupyter-server
# nbclassic
# notebook
# terminado
tqdm==4.66.1
# via
# -c requirements/prod.txt
# nltk
# pytorch-lightning
traitlets==5.9.0
# via
# comm
# ipykernel
# ipython
# ipywidgets
# jupyter-client
# jupyter-core
# jupyter-events
# jupyter-server
# matplotlib-inline
# nbclassic
# nbclient
# nbconvert
# nbformat
# notebook
triton==2.1.0
# via
# -c requirements/prod.txt
# torch
typing-extensions==4.7.1
# via
# -c requirements/prod.txt
# pytorch-lightning
# torch
tzdata==2023.3
# via
# -c requirements/prod.txt
# pandas
uri-template==1.3.0
# via jsonschema
urllib3==1.26.16
# via
# -c requirements/prod.txt
# google-auth
# requests
# sentry-sdk
virtualenv==20.24.3
# via pre-commit
wandb==0.12.17
# via -r requirements/dev.in
wcwidth==0.2.6
# via prompt-toolkit
webcolors==1.13
# via jsonschema
webencodings==0.5.1
# via
# bleach
# tinycss2
websocket-client==1.6.2
# via jupyter-server
werkzeug==2.3.7
# via tensorboard
wheel==0.41.2
# via tensorboard
widgetsnbextension==4.0.8
# via ipywidgets
yarl==1.9.2
# via
# -c requirements/prod.txt
# aiohttp
# The following packages are considered to be unsafe in a requirements file:
# setuptools
================================================
FILE: requirements/prod.in
================================================
h5py
importlib-metadata>=4.4
numpy
pyngrok>=6.0,<6.1
requests
smart_open[s3]
tqdm
# versioned for stability
gantry==0.4.9
gradio==3.40.1
# versioned to match Google Colab up to minor
Jinja2>=3.1,<3.2
pillow>=9.4,<9.5
torch>=2.1,<2.2
torchvision>=0.16,<0.17
================================================
FILE: requirements/prod.txt
================================================
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile requirements/prod.in
#
aiofiles==23.2.1
# via gradio
aiohttp==3.8.5
# via gradio
aiosignal==1.3.1
# via aiohttp
altair==5.0.1
# via gradio
annotated-types==0.5.0
# via pydantic
anyio==3.7.1
# via
# httpcore
# starlette
async-timeout==4.0.3
# via aiohttp
attrs==23.1.0
# via
# aiohttp
# jsonschema
# referencing
boto3==1.28.34
# via
# boto3-extensions
# smart-open
boto3-extensions==0.20.0
# via gantry
botocore==1.31.34
# via
# boto3
# boto3-extensions
# s3transfer
cachetools==4.2.4
# via gantry
certifi==2023.7.22
# via
# httpcore
# httpx
# requests
charset-normalizer==3.2.0
# via
# aiohttp
# requests
click==8.1.7
# via
# gantry
# uvicorn
click-spinner==0.1.10
# via gantry
colorama==0.4.6
# via
# gantry
# halo
# log-symbols
contourpy==1.1.0
# via matplotlib
cycler==0.11.0
# via matplotlib
dateparser==1.1.8
# via gantry
exceptiongroup==1.1.3
# via anyio
fastapi==0.101.1
# via gradio
ffmpy==0.3.1
# via gradio
filelock==3.12.2
# via
# huggingface-hub
# torch
# triton
fonttools==4.42.1
# via matplotlib
frozenlist==1.4.0
# via
# aiohttp
# aiosignal
fsspec==2023.6.0
# via
# gradio-client
# huggingface-hub
# torch
gantry==0.4.9
# via -r requirements/prod.in
gradio==3.40.1
# via -r requirements/prod.in
gradio-client==0.5.0
# via gradio
h11==0.14.0
# via
# httpcore
# uvicorn
h5py==3.9.0
# via -r requirements/prod.in
halo==0.0.31
# via gantry
httpcore==0.17.3
# via httpx
httpx==0.24.1
# via
# gradio
# gradio-client
huggingface-hub==0.16.4
# via
# gradio
# gradio-client
idna==3.4
# via
# anyio
# httpx
# requests
# yarl
importlib-metadata==6.8.0
# via -r requirements/prod.in
importlib-resources==6.0.1
# via gradio
isodate==0.6.1
# via gantry
jinja2==3.1.2
# via
# -r requirements/prod.in
# altair
# gradio
# torch
jmespath==1.0.1
# via
# boto3
# botocore
jsonschema==4.19.0
# via altair
jsonschema-specifications==2023.7.1
# via jsonschema
kiwisolver==1.4.5
# via matplotlib
linkify-it-py==2.0.2
# via markdown-it-py
log-symbols==0.0.14
# via halo
markdown-it-py[linkify]==2.2.0
# via
# gradio
# mdit-py-plugins
markupsafe==2.1.3
# via
# gradio
# jinja2
marshmallow==3.20.1
# via
# gantry
# marshmallow-oneofschema
marshmallow-oneofschema==3.0.1
# via gantry
matplotlib==3.7.2
# via gradio
mdit-py-plugins==0.3.3
# via gradio
mdurl==0.1.2
# via markdown-it-py
monotonic==1.6
# via gantry
mpmath==1.3.0
# via sympy
multidict==6.0.4
# via
# aiohttp
# yarl
networkx==3.1
# via torch
numpy==1.25.2
# via
# -r requirements/prod.in
# altair
# contourpy
# gantry
# gradio
# h5py
# matplotlib
# pandas
# torchvision
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.18.1
# via torch
nvidia-nvjitlink-cu12==12.3.101
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
orjson==3.9.5
# via gradio
packaging==23.1
# via
# gradio
# gradio-client
# huggingface-hub
# marshmallow
# matplotlib
pandas==2.0.3
# via
# altair
# gantry
# gradio
pillow==9.4.0
# via
# -r requirements/prod.in
# gradio
# matplotlib
# torchvision
pydantic==2.3.0
# via
# fastapi
# gradio
pydantic-core==2.6.3
# via pydantic
pydub==0.25.1
# via gradio
pyngrok==6.0.0
# via -r requirements/prod.in
pyparsing==3.0.9
# via matplotlib
python-dateutil==2.8.2
# via
# botocore
# dateparser
# gantry
# matplotlib
# pandas
python-multipart==0.0.6
# via gradio
pytz==2023.3
# via
# dateparser
# pandas
pyyaml==6.0.1
# via
# gantry
# gradio
# huggingface-hub
# pyngrok
referencing==0.30.2
# via
# jsonschema
# jsonschema-specifications
regex==2023.8.8
# via dateparser
requests==2.31.0
# via
# -r requirements/prod.in
# gantry
# gradio
# gradio-client
# huggingface-hub
# torchvision
rpds-py==0.9.2
# via
# jsonschema
# referencing
s3transfer==0.6.2
# via boto3
semantic-version==2.10.0
# via gradio
six==1.16.0
# via
# halo
# isodate
# python-dateutil
smart-open[s3]==6.3.0
# via -r requirements/prod.in
sniffio==1.3.0
# via
# anyio
# httpcore
# httpx
spinners==0.0.24
# via halo
starlette==0.27.0
# via fastapi
sympy==1.12
# via torch
tabulate==0.9.0
# via gantry
termcolor==2.3.0
# via halo
toolz==0.12.0
# via altair
torch==2.1.1
# via
# -r requirements/prod.in
# torchvision
torchvision==0.16.1
# via -r requirements/prod.in
tqdm==4.66.1
# via
# -r requirements/prod.in
# gantry
# huggingface-hub
triton==2.1.0
# via torch
typeguard==2.13.3
# via gantry
typing-extensions==4.7.1
# via
# altair
# fastapi
# gantry
# gradio
# gradio-client
# huggingface-hub
# pydantic
# pydantic-core
# torch
# uvicorn
tzdata==2023.3
# via pandas
tzlocal==5.0.1
# via dateparser
uc-micro-py==1.0.2
# via linkify-it-py
urllib3==1.26.16
# via
# botocore
# requests
uvicorn==0.23.2
# via gradio
websockets==11.0.3
# via
# gradio
# gradio-client
yarl==1.9.2
# via aiohttp
zipp==3.16.2
# via importlib-metadata
================================================
FILE: setup/readme.md
================================================
# Setup
Deep learning requires access to accelerated computation hardware.
Most commonly, those are NVIDIA GPUs or Google TPUs.
If you have access to a computer that has an NVIDIA GPU and runs Linux, you're welcome to [set it up](#Local) for local use.
If you don't, you can get free compute with [Google Colab](#Colab).
## Colab
Google Colab is a great way to get access to fast GPUs for free.
All you need is a Google account.
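The preferred way to interact with the labs on Colab is just to click on badges like this one: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://fsdl.me/lab00-colab)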
All setup is handled automatically,
so you can immediately start working on the labs.
But if you just want to use the codebase, then
go to [https://colab.research.google.com](https://colab.research.google.com)
and create a new notebook.
Connect your new notebook to a GPU runtime via Runtime > Change runtime type, then select a GPU as the hardware accelerator.

Now, run `!nvidia-smi` in the first cell (press Shift+Enter to run a cell).
You should see a table showing your precious GPU :)
Now, paste the following into a cell and run it:
```py
# FSDL 2022 Setup
lab_idx = None
if "bootstrap" not in locals() or bootstrap.run:
    # path management for Python
    pythonpath, = !echo $PYTHONPATH
    if "." not in pythonpath.split(":"):
        pythonpath = ".:" + pythonpath
        %env PYTHONPATH={pythonpath}
        !echo $PYTHONPATH
    # get both Colab and local notebooks into the same state
    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py
    import bootstrap
    # change into the lab directory
    bootstrap.change_to_lab_dir(lab_idx=lab_idx)
    # allow "hot-reloading" of modules
    %load_ext autoreload
    %autoreload 2
    bootstrap.run = False  # change to True to re-run setup
!pwd
%ls
```
The bootstrap script will
check out our lab repository,
`cd` into it,
and install required packages.
The snippet also adds the current working directory to `PYTHONPATH`, so Python can find packages there.
From there, you can `%cd` into a lab folder
to play around with the codebase for that lab,
either by directly writing Python,
e.g. `import text_recognizer`,
or by running shell commands, like
`!python training/run_experiment.py`.
### Colab Pro
You may be interested in signing up for [Colab Pro](https://colab.research.google.com/signup).
For $10/month, you get priority access to faster GPUs (e.g. [P100 vs K80](https://www.xcelerit.com/computing-benchmarks/insights/nvidia-p100-vs-k80-gpu/)) and TPUs, a 24h rather than 12h maximum runtime, and more RAM.
## Local
Setting up a machine that you can sit in front of or SSH into is easy.
### Watch a walkthrough video [here](https://fsdl.me/2022-local-setup-video).
If you get stuck, it's better to at least [get started with the labs on Colab](https://fsdl.me/lab00-colab), where setup is just a single click, rather than getting frustrated and burning out on annoying environment management, networking, and systems administration issues that aren't as relevant to making ML-powered products.
### Summary
- `environment.yml` specifies Python and, optionally, CUDA/cuDNN
- `make conda-update` creates/updates a virtual environment
- `conda activate fsdl-text-recognizer-2022` activates the virtual environment
- `requirements/prod.in` and `requirements/dev.in` specify core Python packages in that environment
- `make pip-tools` resolves all other Python dependencies and installs them
- `export PYTHONPATH=.:$PYTHONPATH` makes the current directory visible on your Python path -- add it to your `~/.bashrc` and `source ~/.bashrc`
### 1. Check out the repo
```
git clone https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs.git
cd fsdl-text-recognizer-2022-labs
```
### 2. Set up the Python environment
We use
[`conda`](https://docs.conda.io/en/latest/miniconda.html)
for managing Python and CUDA versions, and
[`pip-tools`](https://github.com/jazzband/pip-tools)
for managing Python package dependencies.
We also provide a `Makefile` to make setup dead simple.
#### First: Install the Python + CUDA environment using Anaconda
Conda is an open-source package management system and environment management system that runs on Windows, macOS, and Linux.
It is most closely associated with Python, but
[in fact it can manage more than just Python environments](https://jakevdp.github.io/blog/2016/08/25/conda-myths-and-misconceptions/).
To install `conda`, follow instructions at https://conda.io/projects/conda/en/latest/user-guide/install/linux.html.
Conda will install the appropriate version of Python for you in the project environment,
so it doesn't matter which installer you choose.
In the project we use the version of Python used in Google Colab,
which at time of writing is Python 3.10.
Note that you will likely need to close and re-open your terminal.
Afterwards, you should be able to run the `conda` command in your terminal.
Run `make conda-update` to create an environment called `fsdl-text-recognizer-2022`, as defined in `environment.yml`.
This environment will provide us with the right Python version as well as the CUDA and cuDNN libraries.
If you edit `environment.yml`, just run `make conda-update` again to get the latest changes.
Next, activate the conda environment.
```sh
conda activate fsdl-text-recognizer-2022
```
**IMPORTANT**: every time you work in this directory, make sure to start your session with `conda activate fsdl-text-recognizer-2022`.
#### Next: install Python packages
Next, install all necessary Python packages by running `make pip-tools`.
Using `pip-tools` lets us do three nice things:
1. Separate out dev from production dependencies (`dev.in` vs `prod.in`).
2. Have a lockfile of exact versions for all dependencies (the auto-generated `dev.txt` and `prod.txt`).
3. Allow us to easily deploy to targets that don't support the `conda` environment, like Colab.
#### Set PYTHONPATH
Last, run `export PYTHONPATH=.:$PYTHONPATH` before executing any commands later on, or you will get errors like this:
```python
ModuleNotFoundError: No module named 'text_recognizer'
```
To avoid having to set `PYTHONPATH` in every terminal you open, add that line as the last line of your `~/.bashrc` file, either with a text editor of your choice (e.g. `nano ~/.bashrc`) or by appending it with `>>`. Use single quotes so that `$PYTHONPATH` is expanded each time a shell starts, not once when you run `echo`:
```bash
echo 'export PYTHONPATH=.:$PYTHONPATH' >> ~/.bashrc
```