Repository: full-stack-deep-learning/fsdl-text-recognizer-2022-labs Branch: main Commit: 485b963cb41d Files: 424 Total size: 3.9 MB Directory structure: gitextract_8c13vqgo/ ├── .flake8 ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ └── this-repository-is-automatically-generated--don-t-open-issues-here-.md │ └── pull_request_template.md ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.txt ├── Makefile ├── data/ │ └── raw/ │ ├── emnist/ │ │ ├── metadata.toml │ │ └── readme.md │ ├── fsdl_handwriting/ │ │ ├── fsdl_handwriting.jsonl │ │ ├── manifest.csv │ │ ├── metadata.toml │ │ └── readme.md │ └── iam/ │ ├── metadata.toml │ └── readme.md ├── environment.yml ├── lab01/ │ ├── notebooks/ │ │ └── lab01_pytorch.ipynb │ └── text_recognizer/ │ ├── __init__.py │ ├── data/ │ │ └── util.py │ ├── metadata/ │ │ ├── mnist.py │ │ └── shared.py │ ├── models/ │ │ ├── __init__.py │ │ └── mlp.py │ └── util.py ├── lab02/ │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ └── lab02b_cnn.ipynb │ ├── text_recognizer/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ └── base.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ └── mlp.py │ │ ├── stems/ │ │ │ └── image.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── run_experiment.py │ └── util.py ├── lab03/ │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ ├── lab02b_cnn.ipynb │ │ └── lab03_transformers.ipynb │ ├── text_recognizer/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── 
iam.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── metrics.py │ │ │ ├── transformer.py │ │ │ └── util.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ ├── mlp.py │ │ │ ├── resnet_transformer.py │ │ │ └── transformer_util.py │ │ ├── stems/ │ │ │ ├── image.py │ │ │ └── paragraph.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── run_experiment.py │ └── util.py ├── lab04/ │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ ├── lab02b_cnn.ipynb │ │ ├── lab03_transformers.ipynb │ │ └── lab04_experiments.ipynb │ ├── text_recognizer/ │ │ ├── __init__.py │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── imtotext.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ └── util.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── metrics.py │ │ │ ├── transformer.py │ │ │ └── util.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ ├── line_cnn_transformer.py │ │ │ ├── mlp.py │ │ │ ├── resnet_transformer.py │ │ │ └── transformer_util.py │ │ ├── stems/ │ │ │ ├── image.py │ │ │ ├── line.py │ │ │ └── paragraph.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── run_experiment.py │ └── util.py ├── lab05/ │ ├── 
.flake8 │ ├── .github/ │ │ └── workflows/ │ │ └── pre-commit.yml │ ├── .pre-commit-config.yaml │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ ├── lab02b_cnn.ipynb │ │ ├── lab03_transformers.ipynb │ │ ├── lab04_experiments.ipynb │ │ └── lab05_troubleshooting.ipynb │ ├── tasks/ │ │ └── lint.sh │ ├── text_recognizer/ │ │ ├── __init__.py │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── imtotext.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ └── util.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── fake_images.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── metrics.py │ │ │ ├── transformer.py │ │ │ └── util.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ ├── line_cnn_transformer.py │ │ │ ├── mlp.py │ │ │ ├── resnet_transformer.py │ │ │ └── transformer_util.py │ │ ├── stems/ │ │ │ ├── image.py │ │ │ ├── line.py │ │ │ └── paragraph.py │ │ ├── tests/ │ │ │ ├── test_callback_utils.py │ │ │ └── test_iam.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── run_experiment.py │ ├── tests/ │ │ ├── test_memorize_iam.sh │ │ └── test_run_experiment.sh │ └── util.py ├── lab06/ │ ├── .flake8 │ ├── .github/ │ │ └── workflows/ │ │ └── pre-commit.yml │ ├── .pre-commit-config.yaml │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ ├── lab02b_cnn.ipynb │ │ ├── lab03_transformers.ipynb │ │ ├── lab04_experiments.ipynb │ │ ├── lab05_troubleshooting.ipynb │ │ └── lab06_data.ipynb │ ├── tasks/ │ │ └── lint.sh │ ├── 
text_recognizer/ │ │ ├── __init__.py │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── imtotext.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ └── util.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── fake_images.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_original_and_synthetic_paragraphs.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── iam_synthetic_paragraphs.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── metrics.py │ │ │ ├── transformer.py │ │ │ └── util.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── iam_synthetic_paragraphs.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ ├── line_cnn_transformer.py │ │ │ ├── mlp.py │ │ │ ├── resnet_transformer.py │ │ │ └── transformer_util.py │ │ ├── stems/ │ │ │ ├── image.py │ │ │ ├── line.py │ │ │ └── paragraph.py │ │ ├── tests/ │ │ │ ├── test_callback_utils.py │ │ │ └── test_iam.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── run_experiment.py │ ├── tests/ │ │ ├── test_memorize_iam.sh │ │ └── test_run_experiment.sh │ └── util.py ├── lab07/ │ ├── .flake8 │ ├── .github/ │ │ └── workflows/ │ │ └── pre-commit.yml │ ├── .pre-commit-config.yaml │ ├── api_serverless/ │ │ ├── Dockerfile │ │ ├── __init__.py │ │ └── api.py │ ├── app_gradio/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── __init__.py │ │ ├── app.py │ │ └── tests/ │ │ └── test_app.py │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ ├── lab02b_cnn.ipynb │ │ ├── lab03_transformers.ipynb │ │ ├── lab04_experiments.ipynb │ │ ├── lab05_troubleshooting.ipynb │ │ ├── lab06_data.ipynb │ │ └── lab07_deployment.ipynb │ ├── tasks/ │ │ 
└── lint.sh │ ├── text_recognizer/ │ │ ├── __init__.py │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── imtotext.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ └── util.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── fake_images.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_original_and_synthetic_paragraphs.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── iam_synthetic_paragraphs.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── metrics.py │ │ │ ├── transformer.py │ │ │ └── util.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── iam_synthetic_paragraphs.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ ├── line_cnn_transformer.py │ │ │ ├── mlp.py │ │ │ ├── resnet_transformer.py │ │ │ └── transformer_util.py │ │ ├── paragraph_text_recognizer.py │ │ ├── stems/ │ │ │ ├── image.py │ │ │ ├── line.py │ │ │ └── paragraph.py │ │ ├── tests/ │ │ │ ├── test_callback_utils.py │ │ │ └── test_iam.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── cleanup_artifacts.py │ ├── run_experiment.py │ ├── stage_model.py │ ├── tests/ │ │ ├── test_memorize_iam.sh │ │ ├── test_model_development.sh │ │ └── test_run_experiment.sh │ └── util.py ├── lab08/ │ ├── .flake8 │ ├── .github/ │ │ └── workflows/ │ │ └── pre-commit.yml │ ├── .pre-commit-config.yaml │ ├── api_serverless/ │ │ ├── Dockerfile │ │ ├── __init__.py │ │ └── api.py │ ├── app_gradio/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── __init__.py │ │ ├── app.py │ │ ├── flagging.py │ │ ├── s3_util.py │ │ └── tests/ │ │ └── test_app.py │ ├── notebooks/ │ │ ├── lab01_pytorch.ipynb │ │ ├── lab02a_lightning.ipynb │ │ ├── lab02b_cnn.ipynb 
│ │ ├── lab03_transformers.ipynb │ │ ├── lab04_experiments.ipynb │ │ ├── lab05_troubleshooting.ipynb │ │ ├── lab06_data.ipynb │ │ ├── lab07_deployment.ipynb │ │ └── lab08_monitoring.ipynb │ ├── tasks/ │ │ └── lint.sh │ ├── text_recognizer/ │ │ ├── __init__.py │ │ ├── callbacks/ │ │ │ ├── __init__.py │ │ │ ├── imtotext.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ └── util.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base_data_module.py │ │ │ ├── emnist.py │ │ │ ├── emnist_essentials.json │ │ │ ├── emnist_lines.py │ │ │ ├── fake_images.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_original_and_synthetic_paragraphs.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── iam_synthetic_paragraphs.py │ │ │ ├── mnist.py │ │ │ ├── sentence_generator.py │ │ │ └── util.py │ │ ├── lit_models/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── metrics.py │ │ │ ├── transformer.py │ │ │ └── util.py │ │ ├── metadata/ │ │ │ ├── emnist.py │ │ │ ├── emnist_lines.py │ │ │ ├── iam.py │ │ │ ├── iam_lines.py │ │ │ ├── iam_paragraphs.py │ │ │ ├── iam_synthetic_paragraphs.py │ │ │ ├── mnist.py │ │ │ └── shared.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cnn.py │ │ │ ├── line_cnn.py │ │ │ ├── line_cnn_simple.py │ │ │ ├── line_cnn_transformer.py │ │ │ ├── mlp.py │ │ │ ├── resnet_transformer.py │ │ │ └── transformer_util.py │ │ ├── paragraph_text_recognizer.py │ │ ├── stems/ │ │ │ ├── image.py │ │ │ ├── line.py │ │ │ └── paragraph.py │ │ ├── tests/ │ │ │ ├── test_callback_utils.py │ │ │ └── test_iam.py │ │ └── util.py │ └── training/ │ ├── __init__.py │ ├── cleanup_artifacts.py │ ├── run_experiment.py │ ├── stage_model.py │ ├── tests/ │ │ ├── test_memorize_iam.sh │ │ ├── test_model_development.sh │ │ └── test_run_experiment.sh │ └── util.py ├── overview.ipynb ├── pyproject.toml ├── readme.md ├── requirements/ │ ├── dev-lint.in │ ├── dev.in │ ├── dev.txt │ ├── prod.in │ └── prod.txt └── setup/ └── readme.md ================================================ FILE CONTENTS 
================================================ ================================================ FILE: .flake8 ================================================ [flake8] select = ANN,B,B9,BLK,C,D,E,F,I,S,W # only check selected error codes max-complexity = 12 # C9 - flake8 McCabe Complexity checker -- threshold max-line-length = 120 # E501 - flake8 -- line length too long, actually handled by black extend-ignore = # E W - flake8 PEP style check E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks # S - flake8-bandit safety check S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password # ANN - flake8-annotations type annotation check ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some # D1 - flake8-docstrings docstring style check D100,D102,D103,D104,D105, # missing docstrings # D2 D4 - flake8-docstrings docstring style check D200,D205,D400,D401, # whitespace issues and first line content # DAR - flake8-darglint docstring correctness check DAR103, # mismatched or missing type in docstring application-import-names = app_gradio,text_recognizer,tests,training # flake8-import-order: which names are first party? import-order-style = google # flake8-import-order: which import order style guide do we use? docstring-convention = numpy # flake8-docstrings: which docstring style guide do we use? strictness = short # darglint: how "strict" are we with docstring completeness? docstring-style = numpy # darglint: which docstring style guide do we use? suppress-none-returning = true # flake8-annotations: do we allow un-annotated Nones in returns? mypy-init-return = true # flake8-annotations: do we allow init to have no return annotation? 
per-file-ignores = # list of case-by-case ignores, see files for details */__init__.py:F401,I */data/*.py:DAR data/*.py:F,I *text_recognizer/util.py:DAR101,F401 *training/run_experiment.py:I202 *app_gradio/app.py:I202 ================================================ FILE: .github/ISSUE_TEMPLATE/this-repository-is-automatically-generated--don-t-open-issues-here-.md ================================================ --- name: This repository is automatically generated! Don't open issues here. about: Open issues in the generating repo instead, at https://fsdl.me/2022-repo. title: '' labels: '' assignees: '' --- Thanks for your interest in contributing! This repository is automatically generated from a source repo, so the preferred place for issues and the only place for PRs is there. So please open your issues [there](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022). Looking forward to hearing from you! ================================================ FILE: .github/pull_request_template.md ================================================ Thanks for your interest in contributing! This repository is automatically generated from [a source repo](https://fsdl.me/2022-repo), so the preferred place for issues and the only place for PRs is there. So please open your issues [there](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022). Looking forward to hearing from you! ================================================ FILE: .gitignore ================================================ # Data data/downloaded data/processed data/interim # Editors .vscode *.sw? 
*~ # Node node_modules # Python __pycache__ .pytest_cache # notebooks .ipynb_checkpoints *.nbconvert*.ipynb .notebook_test.sh # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # logging wandb *.pt *.ckpt lightning_logs/ logs */training/logs */training/*sweep.yaml flagged # Misc .aws/credentials .DS_Store .env _labs .mypy_cache lab9/requirements.txt .coverage* /requirements.txt requirements/dev-lint.txt bootstrap.py **/fixme.py .server.env ================================================ FILE: .pre-commit-config.yaml ================================================ repos: # a set of useful Python-based pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: # list of definitions and supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace # removes any whitespace at the ends of lines - id: check-toml # check toml syntax by loading all toml files - id: check-yaml # check yaml syntax by loading all yaml files - id: check-json # check-json syntax by loading all json files - id: check-merge-conflict # check for files with merge conflict strings args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge - id: check-added-large-files # check that no "large" files have been added args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server - id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.) - id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.) 
# black python autoformatting - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black # additional configuration of black in pyproject.toml # flake8 python linter with all the fixins - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08) additional_dependencies: [ flake8-bandit, flake8-bugbear, flake8-docstrings, flake8-import-order, darglint, mypy, pycodestyle, pydocstyle] args: ["--config", ".flake8"] # additional configuration of flake8 and extensions in .flake8 # shellcheck-py for linting shell files - repo: https://github.com/shellcheck-py/shellcheck-py rev: v0.8.0.4 hooks: - id: shellcheck ================================================ FILE: LICENSE.txt ================================================ MIT License Copyright (c) 2022 Full Stack Deep Learning, LLC Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: Makefile ================================================ # Arcane incantation to print all the other targets, from https://stackoverflow.com/a/26339924 help: @$(MAKE) -pRrq -f $(lastword $(MAKEFILE_LIST)) : 2>/dev/null | awk -v RS= -F: '/^# File/,/^# Finished Make data base/ {if ($$1 !~ "^[#.]") {print $$1}}' | sort | egrep -v -e '^[^[:alnum:]]' -e '^$@$$' # Install exact Python and CUDA versions conda-update: conda env update --prune -f environment.yml echo "!!!RUN THE conda activate COMMAND ABOVE RIGHT NOW!!!" # Compile and install exact pip packages pip-tools: pip install pip-tools==6.13.0 setuptools==67.7.2 pip-compile requirements/prod.in && pip-compile requirements/dev.in pip-sync requirements/prod.txt requirements/dev.txt # Compile and install the requirements for local linting (optional) pip-tools-lint: pip install pip-tools==6.13.0 setuptools==67.7.2 pip-compile requirements/prod.in && pip-compile requirements/dev.in && pip-compile requirements/dev-lint.in pip-sync requirements/prod.txt requirements/dev.txt requirements/dev-lint.txt # Bump versions of transitive dependencies pip-tools-upgrade: pip install pip-tools==6.13.0 setuptools==67.7.2 pip-compile --upgrade requirements/prod.in && pip-compile --upgrade requirements/dev.in && pip-compile --upgrade requirements/dev-lint.in # Example training command train-mnist-cnn-ddp: python training/run_experiment.py --max_epochs=10 --gpus=-1 --accelerator=ddp --num_workers=20 --data_class=MNIST --model_class=CNN # Lint lint: tasks/lint.sh # Test notebooks in source repo test-notebooks: tasks/notebook_test.sh $(SELECT_BY) # Test all lab notebooks from the folder for provided lab InDeX test-labs-up-to: cd lab$(IDX) && ./.notebook_test.sh # Test only the notebooks for the provided lab InDeX test-lab: cd lab$(IDX) && ./.notebook_test.sh $(IDX) ================================================ FILE: data/raw/emnist/metadata.toml 
================================================ filename = 'matlab.zip' sha256 = 'e1fa805cdeae699a52da0b77c2db17f6feb77eed125f9b45c022e7990444df95' url = 'https://s3-us-west-2.amazonaws.com/fsdl-public-assets/matlab.zip' ================================================ FILE: data/raw/emnist/readme.md ================================================ # EMNIST dataset "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset Original url is http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/matlab.zip We uploaded the same file to our S3 bucket for faster download. ================================================ FILE: data/raw/fsdl_handwriting/fsdl_handwriting.jsonl ================================================ {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/85a46e90-8af1-48fe-9e75-d18f7f05d6e9___a01-000u.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1422924901185771,0.18948824343015216],[0.875494071146245,0.18948824343015216],[0.875494071146245,0.25034578146611347],[0.1422924901185771,0.25034578146611347]],"notes":"A MOVE to stop Mr. 
Gaitskiell from","imageWidth":1240,"imageHeight":1771},{"label":["line"],"shape":"rectangle","points":[[0.1324110671936759,0.24481327800829875],[0.9407114624505929,0.24481327800829875],[0.9407114624505929,0.3029045643153527],[0.1324110671936759,0.3029045643153527]],"notes":"nominating any more Labour life Peers","imageWidth":1240,"imageHeight":1771},{"label":["line"],"shape":"rectangle","points":[[0.13636363636363635,0.29737206085753803],[0.9802371541501976,0.29737206085753803],[0.9802371541501976,0.35408022130013833],[0.13636363636363635,0.35408022130013833]],"notes":"","imageWidth":1240,"imageHeight":1771},{"label":["line"],"shape":"rectangle","points":[[0.1422924901185771,0.34439834024896265],[0.958498023715415,0.34439834024896265],[0.958498023715415,0.40110650069156295],[0.1422924901185771,0.40110650069156295]],"notes":"","imageWidth":1240,"imageHeight":1771},{"label":["line"],"shape":"rectangle","points":[[0.1284584980237155,0.39695712309820197],[0.9268774703557316,0.39695712309820197],[0.9268774703557316,0.4591977869986169],[0.1284584980237155,0.4591977869986169]],"notes":"","imageWidth":1240,"imageHeight":1771}],"extras":null,"metadata":{"first_done_at":1551319523000,"last_updated_at":1551319736000,"sec_taken":0,"last_updated_by":"69FI7aSdl6aSMhn3Anp3BRvA8gg2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/84ae6d5f-1283-46b6-88ec-2e8c87343d1e___Vv_%28Name%29_001.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11607142857142858,0.2962194559704933],[0.8526785714285714,0.2962194559704933],[0.8526785714285714,0.3515444905486399],[0.11607142857142858,0.3515444905486399]],"notes":"Mathematicians seek and use patterns to 
formulate","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.12797619047619047,0.34923928077455046],[0.7901785714285714,0.34923928077455046],[0.7901785714285714,0.39073305670816044],[0.12797619047619047,0.39073305670816044]],"notes":"new conjectures by mathematical proof. When","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.12648809523809523,0.3849700322729368],[0.7455357142857143,0.3849700322729368],[0.7455357142857143,0.4287690179806362],[0.12648809523809523,0.4287690179806362]],"notes":"mathematical structures are good models of","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.13392857142857142,0.4287690179806362],[0.8318452380952381,0.4287690179806362],[0.8318452380952381,0.46680497925311204],[0.13392857142857142,0.46680497925311204]],"notes":"real phenomenona, then mathematical reasoning can","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.11607142857142858,0.46680497925311204],[0.8735119047619048,0.46680497925311204],[0.8735119047619048,0.5094513600737667],[0.11607142857142858,0.5094513600737667]],"notes":"provide insight or predictions about nature.","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.8655753968253969,0.5599738742892271],[0.11656746031746032,0.5599738742892271],[0.11656746031746032,0.5071461502996774],[0.8655753968253969,0.5071461502996774]],"notes":"Through the use of abstraction and logic,","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.8258928571428572,0.6012755494083295],[0.11408730158730158,0.6012755494083295],[0.11408730158730158,0.5522898417089289],[0.8258928571428572,0.5522898417089289]],"notes":"mathematics developed from counting, 
calculation,","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.8444940476190477,0.6396957123098203],[0.11780753968253968,0.6396957123098203],[0.11780753968253968,0.5916705086829569],[0.8444940476190477,0.5916705086829569]],"notes":"measurement, and the systematic study of","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.8544146825396826,0.6896419240817582],[0.13020833333333334,0.6896419240817582],[0.13020833333333334,0.6320116797295221],[0.8544146825396826,0.6320116797295221]],"notes":"the shapes and motions of physical objects","imageWidth":2536,"imageHeight":3274}],"extras":null,"metadata":{"first_done_at":1551411138000,"last_updated_at":1551411138000,"sec_taken":0,"last_updated_by":"69FI7aSdl6aSMhn3Anp3BRvA8gg2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/0f7a387a-6820-44e4-a0bb-43ea9f72023b___Vv_%28Name%29_001.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11934156378600823,0.29936305732484075],[0.8518518518518519,0.29936305732484075],[0.8518518518518519,0.3487261146496815],[0.11934156378600823,0.3487261146496815]],"notes":"","imageWidth":2536,"imageHeight":3274},{"label":["line"],"shape":"rectangle","points":[[0.11934156378600823,0.39171974522292996],[0.8004115226337448,0.39171974522292996],[0.8004115226337448,0.5222929936305732],[0.11934156378600823,0.5222929936305732]],"notes":"","imageWidth":2536,"imageHeight":3274}],"extras":null,"metadata":{"first_done_at":1551642111000,"last_updated_at":1551642111000,"sec_taken":77,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/88918018-1bc8-49c0-ad45-d751aa0e863b___page_0103.jpg","annotation":[],"extras":null,"metadata":{"first_done_at":1551642137000,"last_updated_at":1551642216000,"sec_taken":13,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/153188d0-e001-4485-9c23-4a271b3f6ac5___page_0063.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.06804123711340206,0.3057324840764331],[0.865979381443299,0.3057324840764331],[0.865979381443299,0.3328025477707006],[0.06804123711340206,0.3328025477707006]],"notes":"","imageWidth":1276,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642322000,"last_updated_at":1551642322000,"sec_taken":28,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/efc70808-e131-4eb1-a146-6e1575b4feb9___page_0088.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.14766839378238342,0.268],[0.9145077720207254,0.268],[0.9145077720207254,0.308],[0.14766839378238342,0.308]],"notes":"Mathematical analysis is the branch of mathematics","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.07512953367875648,0.3],[0.8989637305699482,0.3],[0.8989637305699482,0.344],[0.07512953367875648,0.344]],"notes":"dealing with limits and related theories, such as","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.08549222797927461,0.35],[0.9093264248704663,0.35],[0.9093264248704663,0.382],[0.08549222797927461,0.382]],"notes":"differentiation, integration, measure, infinite series, 
and","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.09585492227979274,0.388],[0.4378238341968912,0.388],[0.8782383419689119,0.388],[0.8782383419689119,0.412],[0.2772020725388601,0.41],[0.08549222797927461,0.41]],"notes":"analytic functions. These theories are usually studied in","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.11658031088082901,0.42],[0.4326424870466321,0.422],[0.8264248704663213,0.424],[0.9145077720207254,0.418],[0.9015544041450777,0.448],[0.47668393782383417,0.454],[0.21761658031088082,0.45],[0.10621761658031088,0.444]],"notes":"the context of real and complex numbers and functions,","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.11139896373056994,0.458],[0.44559585492227977,0.464],[0.8808290155440415,0.46],[0.8626943005181347,0.484],[0.33419689119170987,0.498],[0.10621761658031088,0.5]],"notes":"Analytics evolved from solutions, which involves the","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.11398963730569948,0.512],[0.5751295336787565,0.508],[0.8601036269430051,0.502],[0.8523316062176166,0.544],[0.3134715025906736,0.556],[0.10362694300518134,0.542]],"notes":"elementary concepts and techniques of analysis","imageWidth":1275,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642388000,"last_updated_at":1551642743000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/fb971620-c18e-495f-ade8-ee265da4c82c___page_0076.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11637251143436657,0.3042066195892488],[0.8869472493105777,0.3042066195892488],[0.8869472493105777,0.3341418454530844],[0.11637251143436657,0.3341418454530844]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.8785600412792719,0.33495090561156643],[0.11532411043045336,0.33495090561156643],[0.11532411043045336,0.3681223721093303],[0.8785600412792719,0.3681223721093303]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.8565436201970945,0.37135861274325843],[0.12056611545001941,0.37135861274325843],[0.12056611545001941,0.39967571829013],[0.8565436201970945,0.39967571829013]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.8303335950992642,0.4045300792410223],[0.12056611545001941,0.4045300792410223],[0.12056611545001941,0.44174684653119634],[0.8303335950992642,0.44174684653119634]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.8722696352557927,0.4377015457387861],[0.11532411043045336,0.4377015457387861],[0.11532411043045336,0.47896361382137037],[0.8722696352557927,0.47896361382137037]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.8596888232088341,0.48219985445529856],[0.11637251143436657,0.48219985445529856],[0.11637251143436657,0.5226528623794008],[0.8596888232088341,0.5226528623794008]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.8397692041344831,0.5234619225378828],[0.11532411043045336,0.5234619225378828],[0.11532411043045336,0.563914930461985],[0.8397692041344831,0.563914930461985]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shap
e":"rectangle","points":[[0.8722696352557927,0.5582515093526107],[0.11742091243827978,0.5582515093526107],[0.11742091243827978,0.6132676001293897],[0.8722696352557927,0.6132676001293897]],"notes":"","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.22959981985699351,0.6116494798124257],[0.11637251143436657,0.6116494798124257],[0.11637251143436657,0.6432028259932254],[0.22959981985699351,0.6432028259932254]],"notes":"","imageWidth":1273,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642111000,"last_updated_at":1551642325000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/3bbb25d4-f032-49ad-b5b6-1c652054d661___page_0062.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.04492958942654768,0.31433367149243563],[0.28419577480801256,0.30927543129985063],[0.5804946028443109,0.30927543129985063],[0.7272358898231037,0.31505627361318417],[0.926312693574225,0.3331131394393547],[0.9094909427542906,0.3547830158692679],[0.29728769060058546,0.34322130589525557],[0.05894590277463495,0.34249870377450703]],"notes":"Deep learning (also known as deep structured learning or","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.9403397415230607,0.3778898079587543],[0.03931876368652401,0.3778898079587543],[0.03931876368652401,0.34033918599418533],[0.9403397415230607,0.34033918599418533]],"notes":"hierarchical learning) is a part of a broader family","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.9506155016442616,0.40822264660307133],[0.06549186477122727,0.40822264660307133],[0.06549186477122727,0.37211725422734465],[0.9506155016442616,0.37211725422734465]],"notes":"of machine learnig methods based on learning 
data","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.8356581455905015,0.42916989556489065],[0.06455672441435275,0.42916989556489065],[0.06455672441435275,0.3995596843886675],[0.8356581455905015,0.3995596843886675]],"notes":"representations, as opposed to task-specific","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.9057504141292395,0.4645610250964833],[0.06736213728436459,0.4645610250964833],[0.06736213728436459,0.42195216293932924],[0.9057504141292395,0.42195216293932924]],"notes":"","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.9141666691404984,0.49561644051420345],[0.06549186477122727,0.49561644051420345],[0.06549186477122727,0.4587801574358044],[0.9141666691404984,0.4587801574358044]],"notes":"belief networks and recurrent neural networks","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.9506155016442616,0.5259493045058657],[0.05239994077804262,0.5259493045058657],[0.05239994077804262,0.4847856719375541],[0.9506155016442616,0.4847856719375541]],"notes":"have been applied to field including computer","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.9149477172695168,0.5467789079888803],[0.07686220154246657,0.5467789079888803],[0.07686220154246657,0.5130780492825777],[0.9149477172695168,0.5130780492825777]],"notes":"vision, speech recognition, natural language","imageWidth":1275,"imageHeight":1649},{"label":["line"],"shape":"rectangle","points":[[0.6789419844017848,0.5732606315165131],[0.07140723662889453,0.5732606315165131],[0.07140723662889453,0.5479832448567945],[0.6789419844017848,0.5479832448567945]],"notes":"processing, and audio 
recognition","imageWidth":1275,"imageHeight":1649}],"extras":null,"metadata":{"first_done_at":1551642500000,"last_updated_at":1551642500000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/b85e999f-7d09-4a86-ba70-00ea5fdd95f8___page_0102.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.15785041889328175,0.3840635656715723],[0.8367832614937538,0.3840635656715723],[0.8367832614937538,0.34882008552759275],[0.15785041889328175,0.34882008552759275]],"notes":"(i.e. , when a low-probability event occurs), the event","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.1607844415492907,0.41343313245822194],[0.8426513068057717,0.41343313245822194],[0.8426513068057717,0.38044884975936927],[0.1607844415492907,0.38044884975936927]],"notes":"carries more \"information\" (\"surprisal\") than when","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.16313165967409787,0.4536468469814807],[0.8197659300889019,0.4536468469814807],[0.8197659300889019,0.41162577450212046],[0.16313165967409787,0.41162577450212046]],"notes":"the source data has a higher-probability value.","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.17134692311092292,0.4870829691693588],[0.8027485986840499,0.4870829691693588],[0.8027485986840499,0.44867661260220154],[0.17134692311092292,0.44867661260220154]],"notes":"The amount of information conveyed by each","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.17134692311092292,0.5146451779999069],[0.8074430349336642,0.5146451779999069],[0.8074430349336642,0.48392009274618114],[0.17134692311092292,0.48392009274618114]],"notes":"event defined in this way becomes a 
random","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.17193372764212472,0.5476294606987596],[0.7428945365014672,0.5476294606987596],[0.7428945365014672,0.5160006964669831],[0.17193372764212472,0.5160006964669831]],"notes":"variable whose expected value is the","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.8056826213400589,0.5788063854415108],[0.179562186547748,0.5788063854415108],[0.179562186547748,0.5498886581438864],[0.8056826213400589,0.5498886581438864]],"notes":"information entropy. Generally entropy refers","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15639462707566004,0.3481336872710597],[0.81831013827128,0.3481336872710597],[0.15639462707566004,0.3174086020173338],[0.81831013827128,0.3174086020173338]],"notes":"When the data source has a lower-probability value","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.18132260014135337,0.5950726070464244],[0.8772727741466774,0.585132138287866],[0.875512360553072,0.611338828651338],[0.1860170363909677,0.6221829763879472],[0.18366981826616055,0.6090796312062111]],"notes":"to disorder or uncertainty, and the definition","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.17956218654774803,0.6267013712782009],[0.7945333352472249,0.620827457920871],[0.7945333352472249,0.6483896667514192],[0.17838857748534445,0.6574264565319268],[0.17838857748534445,0.6569746170429014],[0.17838857748534445,0.6569746170429014]],"notes":"of entropy used in information theory is","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.16334450764181302,0.6652999622788471],[0.8425915778808454,0.6561940603143804],[0.8462871549768141,0.6857882416988971],[0.1603880459650381,0.7005853323911555]],"notes":"directly analogous to the definition 
used","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.16999654641455653,0.7045691645006097],[0.6762906085622575,0.6966015002817013],[0.6733341468854825,0.7313177515212306],[0.16777920015697537,0.7392854157401388]],"notes":"in statistical thermodynamics","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642302000,"last_updated_at":1551642481000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/886f2e7c-c81e-46ea-91a5-8736740da295___page_0114.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642515000,"last_updated_at":1551642515000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/fa6c1cc0-217a-49a9-a000-540f36ba204e___page_0048.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.1289134438305709,0.3252840909090909],[0.9465930018416207,0.30113636363636365],[0.9465930018416207,0.34801136363636365],[0.13259668508287292,0.3678977272727273]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.1141804788213628,0.3678977272727273],[0.9097605893186004,0.34801136363636365],[0.9097605893186004,0.37926136363636365],[0.1141804788213628,0.4090909090909091]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.1270718232044199,0.40198863636363635],[0.9521178637200737,0.37642045454545453],[0.9484346224677717,0.4119318181818182],[0.1252302025782689,0.4460227272727273]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.10681399631675875,0.4460227272727273],[0.9337016574585635,0.40625],[0.9318600368324125,0.4446022727272727],[0.12154696132596685,0.48
295454545454547]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.13812154696132597,0.4758522727272727],[0.9226519337016574,0.44886363636363635],[0.9152854511970534,0.4815340909090909],[0.13627992633517497,0.5198863636363636]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.11602209944751381,0.515625],[0.9871086556169429,0.48579545454545453],[0.9723756906077348,0.5284090909090909],[0.12154696132596685,0.5639204545454546]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.10313075506445672,0.5582386363636364],[0.9042357274401474,0.5184659090909091],[0.9042357274401474,0.5696022727272727],[0.1252302025782689,0.5965909090909091]],"notes":"","imageWidth":1277,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.1141804788213628,0.5923295454545454],[0.8710865561694291,0.5610795454545454],[0.8710865561694291,0.6036931818181818],[0.12154696132596685,0.6448863636363636]],"notes":"","imageWidth":1277,"imageHeight":1655}],"extras":null,"metadata":{"first_done_at":1551642527000,"last_updated_at":1551642527000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/35db7729-5dcc-4b03-a996-0ceeb1090d88___page_0074.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.05502392344497608,0.3044280442804428],[0.9497607655502392,0.2933579335793358],[0.9425837320574163,0.3247232472324723],[0.04784688995215311,0.3376383763837638]],"notes":"when the data source has a lower probability value (i.e. 
when","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.05263157894736842,0.34317343173431736],[0.9856459330143541,0.32287822878228783],[0.9760765550239234,0.34870848708487084],[0.050239234449760764,0.3671586715867159]],"notes":"a low probability event occurs), the event carries more informaton ' ","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.05502392344497608,0.37084870848708484],[0.9976076555023924,0.35424354243542433],[0.992822966507177,0.3837638376383764],[0.045454545454545456,0.3966789667896679]],"notes":"(\" surprisal \") than when the source data has a higher-probability","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.04784688995215311,0.4003690036900369],[0.9210526315789473,0.37822878228782286],[0.916267942583732,0.4003690036900369],[0.050239234449760764,0.4261992619926199]],"notes":"value. The amount of information conveyed by each event ","imageWidth":1276,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642470000,"last_updated_at":1551642609000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/7c569746-270b-4736-9a28-2b38fb77bc11___page_0060.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.12424559496895629,0.29239130016027703],[0.8647493409839357,0.29239130016027703],[0.8647493409839357,0.335672907091897],[0.12424559496895629,0.335672907091897]],"notes":"From the technology perspective, speech recognition has a","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.8759314445311419,0.3683745656624543],[0.12424559496895629,0.3683745656624543],[0.12424559496895629,0.327016585705573],[0.8759314445311419,0.327016585705573]],"notes":"long history with several waves of major innovations. 
Most","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.11554840332112935,0.41646524003092095],[0.4646785251838965,0.41069435910670493],[0.8548096933864193,0.39338171633405694],[0.8448700457889028,0.36068005776349965],[0.45101150973731136,0.3654891252003463],[0.11679085927081892,0.378954514023517]],"notes":"recently, the field has benefited from advances in deep","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.1205182271198876,0.41646524003092095],[0.4162227431460036,0.4068471051572276],[0.9107202111224496,0.4039616646951196],[0.9069928432733809,0.4510905255762169],[0.7827472483044247,0.4433960176772622],[0.3516150337621463,0.4568614065004329],[0.20997505549753614,0.4664795413741262],[0.11927577117019804,0.4520523390635862]],"notes":"learning and big data. The advances are evidenced not only","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.11430594737143979,0.4674413548614956],[0.5889241201528528,0.4510905255762169],[0.829960574392628,0.44243420418989293],[0.9132051230218288,0.4501287120888476],[0.9082352992230704,0.48475399763414356],[0.4187076550453827,0.49341031902046756],[0.34043293021494025,0.5030284538941608],[0.22736943879319002,0.5049520808688995],[0.11554840332112935,0.5020666404067915]],"notes":"by the surge of academic papers published in the field,","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.11585660358533005,0.5070778522812383],[0.4117365450494037,0.4960394228438236],[0.7798040625935677,0.48569089524624726],[0.9179407822529997,0.4898303062852778],[0.9117023497522512,0.5229455945975219],[0.7628711743772503,0.5429527479528361],[0.5320491718495541,0.5374335332341287],[0.43847268433832604,0.5491618645113818],[0.3056831925366785,0.5484719626715434],[0.11229178501347374,0.539503238753644]],"notes":"but more importantly by the worldwide industry 
adaption","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.11229178501347374,0.5512315700308971],[0.25844934645958245,0.5360537295544519],[0.4046069079056911,0.5484719626715434],[0.4402550936242542,0.5381234350739671],[0.5115514650613804,0.5470921589918666],[0.5979983154288959,0.5339840240349366],[0.7619799697342862,0.5360537295544519],[0.9714130608308443,0.5277749074763909],[0.9580449911863832,0.5802074473041108],[0.6345377057904231,0.5850367601829797],[0.3769795639738047,0.5857266620228181],[0.2040858632387737,0.5843468583431413],[0.1176390128712582,0.5843468583431413]],"notes":"of a variety of deep learning methods in designing and deploying","imageWidth":1278,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.1176390128712582,0.590555974901687],[0.35559065254266686,0.5836569565033028],[0.5008570093458115,0.590555974901687],[0.4919449629161707,0.6264308705732848],[0.12387744537200675,0.625051066893608]],"notes":"speech recognition systems","imageWidth":1278,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642108000,"last_updated_at":1551642449000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/af2350a1-ec02-4282-8302-4e272bac49dc___page_0075.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.14876456634926738,0.28535026568872723],[0.8697005417341785,0.27358324442321275],[0.8747865098144955,0.3206513294852708],[0.14622158230910898,0.3196707443798113]],"notes":"Gradient descent is a first-order iterative 
optimization","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.1538505344295842,0.32555425501256857],[0.8671575576940201,0.32555425501256857],[0.8633430816337826,0.35006888264905717],[0.152579042409505,0.35203005285997624]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.15639351846974264,0.358894148598193],[0.8582571135534657,0.3559523932818144],[0.8506281614329905,0.37458351028554576],[0.1513075503894258,0.3883117017619793]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.16275097857013868,0.3941952123947366],[0.8747865098144955,0.3785058507073839],[0.8722435257743371,0.4089039889766298],[0.16147948655005948,0.42459335066398246]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.16275097857013868,0.42851569108582066],[0.8722435257743371,0.4167486698203061],[0.8798724778948123,0.44028271235133515],[0.17165142271069314,0.45793324424960696]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.16020799452998027,0.46675851019874287],[0.8900444140554459,0.447146808089552],[0.9078453023365548,0.447146808089552],[0.9192887305172677,0.4461662229840924],[0.9205602225373469,0.4795061165697169],[0.16910843867053474,0.48833138251885283]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.16147948655005948,0.4961760633625292],[0.8875014300152875,0.48440904209701463],[0.8798724778948123,0.5206906909990178],[0.16020799452998027,0.5177489356826391]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.16020799452998027,0.5216712761044773],[0.8824154619349707,0.5246130314208559],[0.8684290497140994,0.5746228717992927],[0.1640224705902179,0.5520694143737231]],"notes":"maximum of that function the procedure is then known as gradient 
ascent","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.14876456634926738,0.5579529250064804],[0.48189547561002,0.5657976058501568],[0.4768095075297032,0.5961957441194027],[0.15512202644966344,0.5834481377484285]],"notes":"as gradient ascent","imageWidth":1274,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642588000,"last_updated_at":1551642759000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/727cb1dd-56ae-4960-9308-d3ace1b098a2___page_0101.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.03322180454046187,0.28531308855693593],[0.034198916438710744,0.316259746310466],[0.956592548385652,0.3117309671270226],[0.9546383245891543,0.27776512325119684]],"notes":"Information is the resolution of uncertainty; it is that which answers the question of \"what an","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.035176028336959625,0.32003372896333554],[0.03713025213345738,0.3441872179417005],[0.973203450655883,0.35701875896145685],[0.9751576744523808,0.32003372896333554]],"notes":"entity is\" and is thus that which specifies the nature of that entity, ad well as essentiality","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.03908447592995514,0.3494707936557178],[0.04201581162470178,0.3743790791646566],[0.9761347863506296,0.3947585854901521],[0.9722263387576341,0.35928314855317856]],"notes":"of its properties. 
Information is associated with data and knowledge, as data is meaningful","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.041038699726452896,0.3804174514092479],[0.041038699726452896,0.40834492304048237],[0.9438900937084166,0.43023402242712555],[0.9438900937084166,0.3977777716124477]],"notes":"information and represents the values attributed to parameters, and knowledge signifies","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.04201581162470178,0.4151380918156475],[0.5882213627458248,0.42872442936597777],[0.5852900270510782,0.46269027324180345],[0.04299292352295065,0.4528779183443427]],"notes":"understanding of an abstract or concrete concept.","imageWidth":1276,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642360000,"last_updated_at":1551642533000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/cd013df3-cfcf-4737-bb2c-2f8ff1bca5c0___page_0105.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551641885000,"last_updated_at":1551641885000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/d5990a8f-acd1-4ef6-b445-1c1df96f05c2___page_0111.jpg","annotation":[],"extras":null,"metadata":{"first_done_at":1551642236000,"last_updated_at":1551642236000,"sec_taken":19,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/e19b8034-88ff-49de-bdc9-a1af0a82af0d___page_0071.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.0509868540276039,0.2874260130672446],[0.896441597176236,0.2874260130672446],[0.896441597176236,0.33318537833168155],[0.0509868540276039,0.33318537833168155]],"notes":"From the technology perspective, speech recognition has along history with several waves of","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.02688397757819115,0.33175539816716787],[0.983582765877959,0.33175539816716787],[0.983582765877959,0.36535993203323874],[0.02688397757819115,0.36535993203323874]],"notes":"major innovations. Most recently, the field has benefitted from advances in deep learning and big","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.017613640482263167,0.36535993203323874],[0.9492825186230255,0.36535993203323874],[0.9492825186230255,0.3946745254057687],[0.017613640482263167,0.3946745254057687]],"notes":"ta. 
The advances are evidenced not only by the urge of academic papers published in the","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.024102876449412756,0.3961045055702823],[0.9298148107215767,0.3961045055702823],[0.9298148107215767,0.4268490791073259],[0.024102876449412756,0.4268490791073259]],"notes":"field, but more importantly by the worldwide industry adoption of a variety of deep","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.02781101128778395,0.4625985832201673],[0.8130085633128841,0.4625985832201673],[0.8130085633128841,0.4318540096831237],[0.02781101128778395,0.4318540096831237]],"notes":"learning methods in designing and deploying speech recognition systems.","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642401000,"last_updated_at":1551642401000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/632a93ac-bef9-4642-a4f6-7c9e96af9e9d___page_0059.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.07921326940063682,0.31591498942434393],[0.866793476659842,0.31591498942434393],[0.866793476659842,0.35944105463392023],[0.07921326940063682,0.35944105463392023]],"notes":"","imageWidth":1275,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551641914000,"last_updated_at":1551641914000,"sec_taken":55,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/31d87003-20b1-42cc-939e-dd03589e1694___page_0058.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1621105995278845,0.3695700146550101],[0.8231108609831319,0.3695700146550101],[0.8231108609831319,0.3176542268820444],[0.1621105995278845,0.3176542268820444]],"notes":"Computer 
science is teh study of process that","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.15868572770687286,0.4047671589078682],[0.8162611173411086,0.4047671589078682],[0.8162611173411086,0.37220980047397445],[0.15868572770687286,0.37220980047397445]],"notes":"interact with data and that can be represented","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.16096897558754727,0.44876358922394083],[0.8002783821763876,0.44876358922394083],[0.8002783821763876,0.4074069447268326],[0.16096897558754727,0.4074069447268326]],"notes":"as data in the form of programs. It enables the","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.15982735164721007,0.4575628752871554],[0.7751626554889689,0.4575628752871554],[0.7751626554889689,0.49803959117794216],[0.15982735164721007,0.49803959117794216]],"notes":"use of algorithms to manipulate, store, and","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1689603431699078,0.5490754503445865],[0.7888621427730155,0.5490754503445865],[0.7888621427730155,0.5041990914221923],[0.1689603431699078,0.5041990914221923]],"notes":"communicate digital information. 
A computer","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1689603431699078,0.556114879195158],[0.8105529976394226,0.556114879195158],[0.8105529976394226,0.5913120234480161],[0.1689603431699078,0.5913120234480161]],"notes":"scientist studies the theory of computation","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.8826095814097336,0.5976818707166152],[0.17526104545136137,0.5976818707166152],[0.17526104545136137,0.6491894140466649],[0.8826095814097336,0.6491894140466649]],"notes":"and the practice of designing software systems.","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.16265233714550803,0.6472457331662856],[0.8220877815416375,0.6472457331662856],[0.8220877815416375,0.6948659147355769],[0.16265233714550803,0.6948659147355769]],"notes":"Its fields can be devided into theoretical and ","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1689566912984347,0.6997251169365251],[0.7968703649299308,0.6997251169365251],[0.7968703649299308,0.7385987345441097],[0.1689566912984347,0.7385987345441097]],"notes":"practical disciplines. 
Computational","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1538262413314107,0.7356832132235409],[0.8296530065251495,0.7356832132235409],[0.8296530065251495,0.7823315543526426],[0.1538262413314107,0.7823315543526426]],"notes":"complexity theory is highly abstract, while","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.9216965771578789,0.7658102668694191],[0.1651740788066787,0.7658102668694191],[0.1651740788066787,0.819261491079848],[0.9216965771578789,0.819261491079848]],"notes":"computer graphics emphasizes real-world applications.","imageWidth":1274,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642451000,"last_updated_at":1551642645000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/3dabc6d4-9f8a-4fc1-a7b3-3d9288ceb74f___page_0064.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.12568189620300696,0.32634637601373423],[0.9641124406099086,0.3097740991067868],[0.9674198589310404,0.2804539168868029],[0.1223744778818752,0.29192703166853573]],"notes":"Natural Language processing (NLP) is a subfield of","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.09756884047338699,0.3556665582337182],[0.1041836771156505,0.3301707476076452],[0.9723809864127381,0.3072245180441795],[0.9707272772521722,0.3378194907954671],[0.9707272772521722,0.33654470026416344]],"notes":"computer science, information engineering, and artificial","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.09756884047338699,0.38116236885979116],[0.8946566558661416,0.3531169771711109],[0.8781195642604829,0.3301707476076452],[0.09922254963395287,0.35056739610850357]],"notes":"intelligence concerned with the interactions 
between","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.0876465855099917,0.40793297001716783],[0.09426142215225522,0.37733799726588024],[0.9773421138944357,0.35056739610850357],[0.9740346955733039,0.3747884162032729],[0.9740346955733039,0.3735136256719693],[0.9740346955733039,0.3735136256719693]],"notes":"computes and human (natural) languages, in particular","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.0876465855099917,0.4359783617058481],[0.9773421138944357,0.4155817132049897],[0.9756884047338699,0.38243715939109485],[0.08930029467055757,0.4053833889545605]],"notes":"how to program computers to process and analyze large","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08268545802829405,0.469122915519743],[0.08268545802829405,0.43470357117454445],[0.9012714925084052,0.40920776054847147],[0.8979640741872734,0.4334287806432408]],"notes":"amounts of natural language data. 
Challenges in","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08103174886772817,0.5022674693336379],[0.08103174886772817,0.46784812498843936],[0.9376530940408545,0.43087919958063353],[0.9376530940408545,0.4601993818006174]],"notes":"","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.07441691222546465,0.5315876515536219],[0.07607062138603053,0.5009926788023342],[0.9277308390774592,0.4589245912693138],[0.9211160024351958,0.4920691450832087]],"notes":"recognition, natural language understanding","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.07441691222546465,0.5660069958988203],[0.07607062138603053,0.5303128610223181],[0.5490314413078725,0.5073666314588525],[0.5490314413078725,0.53796160421014]],"notes":"natural language generation","imageWidth":1275,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551597360000,"last_updated_at":1551597360000,"sec_taken":0,"last_updated_by":"69FI7aSdl6aSMhn3Anp3BRvA8gg2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/74143f70-aeb1-4b5c-b97c-1d7b74e57efc___page_0070.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.15682656826568267,0.32670454545454547],[0.8302583025830258,0.3039772727272727],[0.8357933579335793,0.3352272727272727],[0.15867158671586715,0.3565340909090909]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15498154981549817,0.3778409090909091],[0.8782287822878229,0.32954545454545453],[0.9095940959409594,0.3678977272727273],[0.12546125461254612,0.3991477272727273]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.13284132841328414,0.41051136363636365],[0.9280442804428044,0.3650568181818182],[0.9317343173431735,0.40198863636363635],[0.13099630996309963,0.4431818181818182]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.13284132841328414,0.4502840909090909],[0.8690036900369004,0.4161931818181818],[0.8634686346863468,0.4431818181818182],[0.13468634686346864,0.48863636363636365]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.13468634686346864,0.4900568181818182],[0.8837638376383764,0.45454545454545453],[0.8837638376383764,0.4900568181818182],[0.13284132841328414,0.5369318181818182]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.12915129151291513,0.5326704545454546],[0.8726937269372693,0.4971590909090909],[0.8653136531365314,0.53125],[0.12361623616236163,0.5724431818181818]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.14206642066420663,0.5639204545454546],[0.8800738007380073,0.5284090909090909],[0.8745387453874539,0.5610795454545454],[0.14206642066420663,0.6079545454545454]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[
[0.15313653136531366,0.5980113636363636],[0.940959409594096,0.5525568181818182],[0.9501845018450185,0.5894886363636364],[0.14944649446494465,0.6477272727272727]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.14391143911439114,0.640625],[0.9391143911439115,0.6051136363636364],[0.9391143911439115,0.6448863636363636],[0.14022140221402213,0.6846590909090909]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15867158671586715,0.6775568181818182],[0.9760147601476015,0.6477272727272727],[0.977859778597786,0.7017045454545454],[0.15129151291512916,0.7400568181818182]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.16974169741697417,0.71875],[0.988929889298893,0.6832386363636364],[0.9907749077490775,0.7357954545454546],[0.1752767527675277,0.7585227272727273]],"notes":"","imageWidth":1273,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.18265682656826568,0.7613636363636364],[0.9446494464944649,0.7244318181818182],[0.9428044280442804,0.7713068181818182],[0.1863468634686347,0.8110795454545454]],"notes":"","imageWidth":1273,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642063000,"last_updated_at":1551642376000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/a6a99b01-c4c7-44e9-864b-ca8aab68d290___page_0110.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.08617234468937876,0.30124223602484473],[0.905811623246493,0.30124223602484473],[0.905811623246493,0.2748447204968944],[0.08617234468937876,0.2748447204968944]],"notes":"","imageWidth":1280,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551641579000,"last_updated_at":1551641579000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/0d30b30b-ad40-4163-b137-676d77c4f67f___page_0112.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551641900000,"last_updated_at":1551641900000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/44744051-aa5b-45f4-8682-fb06a8acf900___page_0066.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.11244019138755981,0.3044280442804428],[0.9760765550239234,0.27490774907749077],[0.9545454545454546,0.33025830258302585],[0.1076555023923445,0.34501845018450183]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.11004784688995216,0.35424354243542433],[0.12200956937799043,0.3985239852398524],[0.8923444976076556,0.3892988929889299],[0.9019138755980861,0.34501845018450183]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.12200956937799043,0.4151291512915129],[0.9688995215311005,0.3874538745387454],[0.9784688995215312,0.45202952029520294],[0.1291866028708134,0.45018450184501846]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.13157894736842105,0.48154981549815495],[0.1291
866028708134,0.5295202952029521],[0.9641148325358851,0.5166051660516605],[0.9473684210526315,0.466789667896679]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.1339712918660287,0.5369003690036901],[0.1339712918660287,0.5793357933579336],[0.937799043062201,0.5627306273062731],[0.9354066985645934,0.518450184501845]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.1339712918660287,0.5885608856088561],[0.9712918660287081,0.5756457564575646],[0.9784688995215312,0.6273062730627307],[0.13636363636363635,0.6199261992619927]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.10047846889952153,0.6439114391143912]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.4880382775119617,0.6143911439114391],[0.7464114832535885,0.6328413284132841],[0.9473684210526315,0.6476014760147601],[0.930622009569378,0.6808118081180812],[0.5287081339712919,0.6697416974169742],[0.39952153110047844,0.6642066420664207],[0.20334928229665072,0.6642066420664207],[0.10526315789473684,0.6863468634686347],[0.11004784688995216,0.6457564575645757]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.12200956937799043,0.6992619926199262],[0.23205741626794257,0.6678966789667896],[0.8803827751196173,0.7084870848708487],[0.8444976076555024,0.7472324723247232],[0.13157894736842105,0.7361623616236163],[0.11483253588516747,0.7453874538745388]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.1076555023923445,0.7361623616236163],[0.12679425837320574,0.7693726937269373],[0.9210526315789473,0.7749077490774908],[0.9282296650717703,0.7472324723247232]],"notes":"","imageWidth":1275,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.11244019138755981,0.7767527675276753],[0.11244019138755981,0.8191881918819188],[0.607
6555023923444,0.8118081180811808],[0.6124401913875598,0.7767527675276753]],"notes":"","imageWidth":1275,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642247000,"last_updated_at":1551642247000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/9b412681-c5a5-43a6-87a4-d4138f644e00___page_0072.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.8928571428571429,0.2875722543352601],[0.16165413533834586,0.2875722543352601],[0.16165413533834586,0.3208092485549133],[0.8928571428571429,0.3208092485549133]],"notes":"","imageWidth":1271,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642383000,"last_updated_at":1551642383000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/1d2e05c3-a70b-4e51-89b5-c621481a0c88___page_0099.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642454000,"last_updated_at":1551642486000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/5713d627-e625-4aef-97ee-48fc4ae0d10a___page_0098.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1618798955613577,0.2796780684104628],[0.9086161879895561,0.2796780684104628],[0.9086161879895561,0.32595573440643866],[0.1618798955613577,0.32595573440643866]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.057441253263707574,0.317907444668008],[0.9451697127937336,0.317907444668008],[0.9451697127937336,0.3722334004024145],[0.057441253263707574,0.3722334004024145]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.07571801566579635,0.3762575452716298],[0.8877284595300261,0.3762575452716298],[0.8877284595300261,0.43259557344064387],[0.07571801566579635,0.43259557344064387]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.056323637429408505,0.43094301319767464],[0.9595133947795663,0.43094301319767464],[0.9595133947795663,0.4805479643571192],[0.056323637429408505,0.4805479643571192]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.0623583128682737,0.48984889269951504],[0.9776174210961619,0.48984889269951504],[0.9776174210961619,0.5394538438589596],[0.0623583128682737,0.5394538438589596]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.07241610526638237,0.5487547722013555],[0.9957214474127576,0.5487547722013555],[0.9957214474127576,0.6138612705981265],[0.07241610526638237,0.6138612705981265]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.9414093684629707,0.6727671500999669],[0.07241610526638237,0.6727671500999669],[0.07241610526638237,0.6138612705981265],[0.9414093684629707,0.6138612705981265]],"notes":"signifies understanding of an abstract 
or","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.10460104094033008,0.685168387889828],[0.3962770204854813,0.685168387889828],[0.3962770204854813,0.7394238032204705],[0.10460104094033008,0.7394238032204705]],"notes":"concrete concept.","imageWidth":1273,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551641700000,"last_updated_at":1551641700000,"sec_taken":187,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/0da2b94d-2d6b-4b6b-a4b8-091819b49bc9___page_0107.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.14661673952495569,0.31196126201482305],[0.8659880802157455,0.31196126201482305],[0.8659880802157455,0.3517688188865062],[0.14661673952495569,0.3517688188865062]],"notes":"Gradient descent isa first-order iterative optimization algorithm","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.14767153621218557,0.3623300074442997],[0.8554401133434465,0.3623300074442997],[0.8554401133434465,0.40701195903496445],[0.14767153621218557,0.40701195903496445]],"notes":"for finding the minimum of a function. 
To find a local minimum","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.16454828320786394,0.40701195903496445],[0.8744264537135846,0.40701195903496445],[0.8744264537135846,0.4500691123859687],[0.16454828320786394,0.4500691123859687]],"notes":"of a function using gradient descent, on takes steps proportional","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15821950308448454,0.4646922965429135],[0.532672327051098,0.44519471766698704],[0.8733716570263547,0.44275752030749627],[0.8744264537135846,0.4720038886213859],[0.4282474550153382,0.4915014674973124],[0.16032909645894433,0.5012502569352756]],"notes":"to the negative of the gradient (or approximate gradient) of the","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.14978112958664538,0.5069370507740875],[0.8680976735902052,0.4776906824601978],[0.8744264537135846,0.5118114454930691],[0.7425768678098474,0.5280594278896744],[0.40504192789628046,0.5337462217284863],[0.14978112958664538,0.538620616447468]],"notes":"function at the current point. 
If, instead, one takes steps proportional","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.14872633289941548,0.5475570067656009],[0.6202204520911795,0.5313090243689955],[0.8723168603391248,0.5256222305301836],[0.8723168603391248,0.5621801909225457],[0.33753493991356703,0.5833025680381327],[0.14978112958664538,0.5833025680381327]],"notes":"to the positive of the gradient, one approaches a local maximum of","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15189072296110517,0.5881769627571143],[0.4662201357556145,0.5768033750794905],[0.8860292172731135,0.562992590042376],[0.8870840139603434,0.5963009539554169],[0.39765835108567116,0.6239225240296461],[0.16032909645894433,0.6271721205089672]],"notes":"that function; the procedure is then known as gradient ascent.","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642270000,"last_updated_at":1551642479000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/50fa8599-62ac-42a2-8141-fe510a164fdb___page_0113.jpg","annotation":[],"extras":null,"metadata":{"first_done_at":1551642233000,"last_updated_at":1551642233000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/45152f96-4d00-43a7-9049-2d7fbab41b73___page_0028.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.12658430181286598,0.2715296999168771],[0.8368274162604412,0.2715296999168771],[0.8368274162604412,0.3047739984993239],[0.12658430181286598,0.3047739984993239]],"notes":"Mathematical analysis is the branch of 
mathematics","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.12489361458636344,0.3323770429267258],[0.8124467565658244,0.3323770429267258],[0.8124467565658244,0.3020286872738697],[0.12489361458636344,0.3020286872738697]],"notes":"dealing with limits and related theories, such as","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.12489361458636344,0.3612295056952805],[0.8932181013391373,0.3612295056952805],[0.8932181013391373,0.3263114553983094],[0.12489361458636344,0.3263114553983094]],"notes":"differentiation, integration, measure, infinite series, and","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.12489361458636344,0.39157786134813654],[0.8755053256420379,0.39157786134813654],[0.8755053256420379,0.3536270329209625],[0.12489361458636344,0.3536270329209625]],"notes":"analytic functions. These theories are usually studied","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.12686170415675385,0.4219466975477875],[0.8262499545482879,0.4219466975477875],[0.8262499545482879,0.38702867851882683],[0.12686170415675385,0.38702867851882683]],"notes":"in the context of real and complex numbers","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.12292552501597303,0.45229508446865396],[0.9089893584555768,0.45229508446865396],[0.9089893584555768,0.40979507321217023],[0.12292552501597303,0.40979507321217023]],"notes":"and functions. 
Analysis evolved from calculus, which","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.12095744559105406,0.4841803253674116],[0.8951861604731134,0.4841803253674116],[0.8951861604731134,0.4386270241659196],[0.12095744559105406,0.4386270241659196]],"notes":"involved the elementary concepts and techniques of","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.11898935602066364,0.5251639319247887],[0.2588563675576068,0.5251639319247887],[0.2588563675576068,0.4917622863269243],[0.11898935602066364,0.4917622863269243]],"notes":"analysis","imageWidth":1274,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551641986000,"last_updated_at":1551642217000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/730580a1-7c34-42d8-805e-95e93ca6aab4___page_0000.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11998585392590445,0.33858742103311024],[0.9481808944388548,0.33858742103311024],[0.9481808944388548,0.37094133015405184],[0.11998585392590445,0.37094133015405184]],"notes":"","imageWidth":1280,"imageHeight":1658},{"label":["line"],"shape":"rectangle","points":[[0.12388783291536476,0.37545582910116],[0.9579358419125055,0.37545582910116],[0.9579358419125055,0.39953315681907003],[0.12388783291536476,0.39953315681907003]],"notes":"","imageWidth":1280,"imageHeight":1658},{"label":["line"],"shape":"rectangle","points":[[0.114132885441714,0.4446781462901514],[0.9442789154493945,0.4446781462901514],[0.9442789154493945,0.4732699729551696],[0.114132885441714,0.4732699729551696]],"notes":"","imageWidth":1280,"imageHeight":1658},{"label":["line"],"shape":"rectangle","points":[[0.11705936968380923,0.4807941378670165],[0.8584353776812677,0.4807941378670165],[0.8584353776812677,0.5048714655849266],[0.11705936968380923,0.5048714655849266]],"
notes":"","imageWidth":1280,"imageHeight":1658}],"extras":null,"metadata":{"first_done_at":1551642485000,"last_updated_at":1551642485000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/84428a31-96f5-47e9-b2b0-4a1e4e33fa0f___page_0014.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.14787980517008467,0.2951591907825168],[0.9830751827973561,0.2506339600562004],[0.9933058611424563,0.2815143620115491],[0.1581104835151851,0.3296303371512783],[0.15625036017971208,0.3310666349166432]],"notes":"gradient descent is a first-order iterative optimization","imageWidth":1279,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.15424697803388596,0.3357447043694809],[0.9826596273020838,0.2885145624704102],[0.9813299119742055,0.32239705557191745],[0.15956583934539925,0.3706539396861853]],"notes":"algorithm for finding the minimum of a function","imageWidth":1279,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.1622252700011559,0.3727074241165797],[0.9946270652529887,0.325477282217509],[0.9919676345972321,0.3521725798126359],[0.16754413131266924,0.4076166594332841]],"notes":"To find a local minimum of a function using gradient","imageWidth":1279,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.1755224232799392,0.41377711272446727],[0.9959567805808671,0.36141325974941063],[0.9946270652529887,0.39632249506611505],[0.17286299262418253,0.45279331690196045]],"notes":"descent, one takes steps proportional to the negative","imageWidth":1279,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.17052120851774863,0.45530151144511877],[0.9593810049316325,0.4023880925474428],[0.9530063803141465,0.4429960651898453],[0.17689583313523455,0.48852621633435717]],"notes":"of the gradient of approximate 
gradient","imageWidth":1279,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.181676801598349,0.4946789394619939],[0.8892601341392872,0.4589931453217008],[0.8908537902936586,0.4872956717088298],[0.1896450823702064,0.530364733602287]],"notes":"of the function at the current point.","imageWidth":1279,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.1450227100478049,0.5352869121043964],[0.9577873487772609,0.4959094840875213],[0.9466317556966606,0.5365174567299238],[0.15777195928277676,0.5882003310020724]],"notes":"If instead one takes steps proportional to","imageWidth":1279,"imageHeight":1655}],"extras":null,"metadata":{"first_done_at":1551642487000,"last_updated_at":1551642487000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/de461a9b-7af6-4049-8b8a-3a49b42eba27___page_0015.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.13551052137336192,0.3129159013602186],[0.9736681906086005,0.3129159013602186],[0.9736681906086005,0.3438976737721214],[0.13551052137336192,0.3438976737721214]],"notes":"When the data source has a lower-probability value (i.e., when a low-probability event","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.13049161317434851,0.3562903827368825],[0.9425509597747175,0.3562903827368825],[0.9425509597747175,0.39037033238997565],[0.13049161317434851,0.39037033238997565]],"notes":"occurs), the event carries more \"information\" (\"surprise\") than when the source","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12848404989474316,0.41747938325039063],[0.9204677636990585,0.41747938325039063],[0.9204677636990585,0.4445884341108056],[0.12848404989474316,0.4445884341108056]],"notes":"data has a higher-probability value. 
The amount of information conveyed by","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1274802682549405,0.4593047760064594],[0.9214715453388611,0.4593047760064594],[0.9214715453388611,0.49106109272865983],[0.1274802682549405,0.49106109272865983]],"notes":"each event defined in this way becomes a random variable whose expected","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1264764866151378,0.5104247004860991],[0.978687098807614,0.5104247004860991],[0.978687098807614,0.5630937135863339],[0.1264764866151378,0.5630937135863339]],"notes":"value is the information entrophy. Generally, entropy refers to a disorder or uncertainty","imageWidth":1276,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642217000,"last_updated_at":1551642409000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/b8c470a9-1856-4205-8e0f-74f528b2727d___page_0001.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.13889398232459746,0.2773116101391781],[0.1420269292943252,0.3119755614065754],[0.8709592575843178,0.3039141773909016],[0.8709592575843178,0.26844408772193695],[0.8699149419277419,0.27328091813134125]],"notes":"Data is measured, collected and reported, and","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.1409826136377493,0.3160062534144123],[0.1409826136377493,0.3490579278786748],[0.894978517685564,0.34663951267397264],[0.8855796767763807,0.31036328460344065]],"notes":"analyzed, where upon it can be visualized using","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13784966666802154,0.35067020468180954],[0.13993829798117338,0.38049732553980253],[0.8709592575843178,0.3748543567288309],[0.8636490479882863,0.340996543863001]],"notes":"graphs, images or other 
analysis tools. Data as","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13784966666802154,0.39097712476017843],[0.13889398232459746,0.4095183079962282],[0.9210864090999621,0.4159674152087672],[0.9095989368776269,0.3788850487366678]],"notes":"a general concept refers to the fact that some","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13889398232459746,0.41516127680719983],[0.13784966666802154,0.4393454288542212],[0.922130724756538,0.44901908967302967],[0.9189977777868102,0.4167735536103346]],"notes":"existing information or knowledge is represented or","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.14411556060747704,0.4506313664761645],[0.1462041919206289,0.4861014561451291],[0.9252636717262657,0.49416284016080286],[0.9231750404131138,0.4554681968855687]],"notes":"coded in some form suitable for better usage or","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.1409826136377493,0.48771373294826387],[0.13784966666802154,0.519959269010959],[0.9169091464736584,0.5280206530266328],[0.9210864090999621,0.503030362578044]],"notes":"processing. 
Raw data (\"unprocessed data\") is a","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.14411556060747704,0.5288267914282001],[0.14411556060747704,0.5634907426955974],[0.9377954596051767,0.5868687563410514],[0.9304852500091453,0.5376943138454413]],"notes":"collection of number of characters before it has","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13889398232459746,0.5683275731050017],[0.1409826136377493,0.6150836003959096],[0.9106432525342029,0.6239511228131508],[0.9148205151605066,0.592511725152023]],"notes":"been \"cleaned\" and corrected by researchers.","imageWidth":1276,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642056000,"last_updated_at":1551642542000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/7f7ecd81-03b4-41fb-8a58-3c0259be2f91___page_0029.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.6018972108873774,0.4227590737457191],[0.5801987220936318,0.33689750627240617],[0.6320864126873712,0.3587267183418925],[0.6320864126873712,0.35945435874420867]],"notes":"","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.7507991593488054,0.3565437971349439]],"notes":"","imageWidth":1274,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642405000,"last_updated_at":1551642405000,"sec_taken":55,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/b872b837-a64b-40dc-bcef-3b2daa9def90___page_0017.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.0751233291352395,0.29338367799802967],[0.9215128373922713,0.29338367799802967],[0.9215128373922713,0.3461412692169736],[0.0751233291352395,0.3461412692169736]],"notes":"Mathematical analysis is the branch of mathematics","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.08180095839170524,0.3795973026728892],[0.8764388399111276,0.3795973026728892],[0.8764388399111276,0.3461412692169736],[0.08180095839170524,0.3461412692169736]],"notes":"declining with limits and related theories, such as","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.08180095839170524,0.389891466813171],[0.9248516520205041,0.389891466813171],[0.9248516520205041,0.4362152054444388],[0.08180095839170524,0.4362152054444388]],"notes":"differentiation integration, measure, infinite series, and","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.07679273644935594,0.44522259906718537],[0.8113319546605866,0.44522259906718537],[0.8113319546605866,0.4825389440757067],[0.07679273644935594,0.4825389440757067]],"notes":"analytic function. 
These theories are usually","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.08680918033405453,0.48768602614584755],[0.9448845397899013,0.48768602614584755],[0.9448845397899013,0.5262891416719041],[0.08680918033405453,0.5262891416719041]],"notes":"studied on the context of real and complex numbers","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.9565703909887163,0.5352965352946506],[0.0801315510775888,0.5352965352946506],[0.0801315510775888,0.5764731918557776],[0.9565703909887163,0.5764731918557776]],"notes":"and functions analysis evolved from calculus, which","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.8680918033405454,0.5816202739259184],[0.07846214376347237,0.5816202739259184],[0.07846214376347237,0.6215101599695102],[0.8680918033405454,0.6215101599695102]],"notes":"involves the elementary concepts and techniques","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.3188567969962388,0.6652603575657076],[0.07679273644935594,0.6652603575657076],[0.07679273644935594,0.6227969304870454],[0.3188567969962388,0.6227969304870454]],"notes":"of analysis.","imageWidth":1274,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642137000,"last_updated_at":1551642256000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/8458cec1-b8d7-45ed-aa8b-285188d70131___page_0003.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642214000,"last_updated_at":1551642758000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/47496a4a-8671-4531-a6ae-952f6312e47e___page_0002.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551641931000,"last_updated_at":1551641931000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/7fd48af1-f073-4dfe-a486-2bfd7c6d40a5___page_0016.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1519707592927914,0.29922677536138453],[0.7747205011773822,0.29922677536138453],[0.7747205011773822,0.3580491329110584],[0.1519707592927914,0.3580491329110584]],"notes":"Natural language processing (NLP) is a subfield","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15527447145133033,0.5869448285934851],[0.822624327476197,0.5869448285934851],[0.822624327476197,0.6137985135618145],[0.15527447145133033,0.6137985135618145]],"notes":"understanding, and natural image generation","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.16023003968913876,0.3644428674273273],[0.8242761835554664,0.3644428674273273],[0.8242761835554664,0.39385404620216424],[0.16023003968913876,0.39385404620216424]],"notes":"of computer science, information engineering, and","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.14371147889644403,0.3976902869119256],[0.8077576227627717,0.3976902869119256],[0.8077576227627717,0.424543971880255],[0.14371147889644403,0.424543971880255]],"notes":"artificial intelligence concerned with the instructions","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15031890321352193,0.45267640375183815],[0.8440984565067001,0.45267640375183815],[0.8440984565067001,0.47697259491365995],[0.15031890321352193,0.47697259491365995]],"notes":"in particular how to program 
computers to process","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15527447145133033,0.5140562551080196],[0.787935349811538,0.5140562551080196],[0.787935349811538,0.5511399153023792],[0.15527447145133033,0.5511399153023792]],"notes":"data. challenge in natural language processing","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.13875591065863563,0.5562549029153944],[0.8738318659335506,0.5562549029153944],[0.8738318659335506,0.5805510940772162],[0.13875591065863563,0.5805510940772162]],"notes":"frequently involve speech recognition, natural language","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.13710405457936617,0.4795300887201675],[0.8655725855372032,0.4795300887201675],[0.8655725855372032,0.5076625205917507],[0.13710405457936617,0.5076625205917507]],"notes":"and analyze large amount of natural language","imageWidth":1281,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.13710405457936617,0.41942898426723985],[0.8870467145677063,0.41942898426723985],[0.8870467145677063,0.45779139136485325],[0.13710405457936617,0.45779139136485325]],"notes":"between computers and human (natural) images,","imageWidth":1281,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642070000,"last_updated_at":1551642703000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/23c1d93f-c104-4fbd-bb1b-3c8e1e19a046___page_0012.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.0832148135216203,0.30503882421183975],[0.8097785616888856,0.30503882421183975],[0.8097785616888856,0.4561779701448554],[0.0832148135216203,0.4561779701448554]],"notes":"","imageWidth":1278,"imageHeight":1656}],"extras":null,"metadata":{"first_done_at":1551642350000,"last_updated_at":1551642350000,"sec_taken":115,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/3fa22029-8053-4435-aad8-5dbbd0486aec___page_0007.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.13433246428237078,0.2807275109597344],[0.8911163472196872,0.2807275109597344],[0.8911163472196872,0.3228880528987421],[0.13433246428237078,0.3228880528987421]],"notes":"Mathematicians seek and use patterns to formulae new","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.12236224469285258,0.3249446647006449],[0.897766469213864,0.3249446647006449],[0.897766469213864,0.3599070653329928],[0.12236224469285258,0.3599070653329928]],"notes":"conjectures; they resolve the truth or falsify of conjectures","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.11571212269867581,0.36196367713489563],[0.8192950296825782,0.36196367713489563],[0.8192950296825782,0.3958977718662921],[0.11571212269867581,0.3958977718662921]],"notes":"by mathematical proof. 
When mathematical","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.1330024398835354,0.3979543836681949],[0.8498855908557913,0.3979543836681949],[0.8498855908557913,0.4339450902014942],[0.1330024398835354,0.4339450902014942]],"notes":"structures are good models of real phenomena, then","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.1396525618777122,0.4380583138052998],[0.920376883994065,0.4380583138052998],[0.920376883994065,0.4689074908338421],[0.1396525618777122,0.4689074908338421]],"notes":"mathematic reasoning can provide insight on predictions","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.12635231788935863,0.46376596132908504],[0.8884562984220166,0.46376596132908504],[0.8884562984220166,0.5079831150699956],[0.12635231788935863,0.5079831150699956]],"notes":"about nature. Through the use of abstraction and logic,","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.1449726594730536,0.5048981973671414],[0.7940245661047064,0.5048981973671414],[0.7940245661047064,0.5388322920985378],[0.1449726594730536,0.5388322920985378]],"notes":"mathematics, developed from counting, calculation,","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.1303423910858647,0.5388322920985378],[0.9123967376010529,0.5388322920985378],[0.9123967376010529,0.5727663868299343],[0.1303423910858647,0.5727663868299343]],"notes":"measurement, and the systematic study of the shapes","imageWidth":1280,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.1423126106753829,0.5614550219194688],[0.5838807110887204,0.5614550219194688],[0.5838807110887204,0.5974457284527681],[0.1423126106753829,0.5974457284527681]],"notes":"and motions of physical 
objects.","imageWidth":1280,"imageHeight":1656}],"extras":null,"metadata":{"first_done_at":1551642538000,"last_updated_at":1551642538000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/a9509c7d-fcf0-4c19-b31d-957ffddb37a9___page_0013.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.13739097105283643,0.2936915170956773],[0.8563516775281906,0.28406228702696656],[0.8610354606322647,0.32619016857757605],[0.13739097105283643,0.3400321868013477]],"notes":"In the simplest case, an optimization problem consists of","imageWidth":1277,"imageHeight":1657},{"label":["line"],"shape":"polygon","points":[[0.13661034053549076,0.3430413211978198],[0.8906994202913998,0.3243846879396928],[0.8985057254648564,0.37012353076606874],[0.1319265574314168,0.3743363189211297]],"notes":"maximizing or minimizing a real function by systematically","imageWidth":1277,"imageHeight":1657},{"label":["line"],"shape":"polygon","points":[[0.1319265574314168,0.37734545331760183],[0.8977250949475107,0.37313266516254084],[0.8906994202913998,0.4002148747307898],[0.13270718794876246,0.4086404510409117]],"notes":"choosing input values from within an allowed set and","imageWidth":1277,"imageHeight":1657},{"label":["line"],"shape":"polygon","points":[[0.1319265574314168,0.4122514123166782],[0.8883575287393628,0.40442766288585075],[0.8906994202913998,0.4459537175571658],[0.1280234048446885,0.4561847745051709]],"notes":"computing the value of the function. 
The generalization","imageWidth":1277,"imageHeight":1657},{"label":["line"],"shape":"polygon","points":[[0.12926569857121709,0.4568707277224312],[0.9292021319372552,0.44716384257252123],[0.9224870307127764,0.4834028804655186],[0.12171120969367841,0.4937568912920893]],"notes":"of optimization theory and techniques to other formulations","imageWidth":1277,"imageHeight":1657},{"label":["line"],"shape":"polygon","points":[[0.11835365908143901,0.4885798858788039],[0.828475613570073,0.48728563452548257],[0.8242786753047738,0.5312901805384079],[0.1301050862242769,0.5267603008017833]],"notes":"constitutes a large area of applied mathematics.","imageWidth":1277,"imageHeight":1657}],"extras":null,"metadata":{"first_done_at":1551642399000,"last_updated_at":1551642421000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/ae4e63e9-61a8-4335-b38d-8a6a6bc5741f___page_0005.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.11409090059286668,0.2833943336948642],[0.9192466847768116,0.2758705903224342],[0.9268527448163361,0.31766916461371214],[0.12061038062674478,0.32937276541526994]],"notes":"From the technology perspective, speech recognition","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.11517748059851303,0.33689650878769994],[0.8638311044888477,0.3277008224436188],[0.8605713644719087,0.36197565336246673],[0.11517748059851303,0.37953105456480346]],"notes":"has a long history with several waves of major","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.11300432058722033,0.3845468834797568],[0.9072943047147017,0.36615551079159453],[0.9040345646977627,0.40126631319626793],[0.10865800056463494,0.41464185696947686]],"notes":"innovations. 
Most recently, the field has benefited","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.10865800056463494,0.4221656003419069],[0.8540518844380306,0.4037742276537446],[0.8475324044041526,0.43136128668598805],[0.11517748059851303,0.4539325168032781]],"notes":"from advances in deeplearning and big data.","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.10974458057028129,0.4606202886898826],[0.10322510053640319,0.49907497703785825],[0.8714371645283722,0.47566777543474265],[0.8627445244832014,0.4464087734308481]],"notes":"The advances are evidenced not only by the","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.10539826054769588,0.5049267774386371],[0.9040345646977627,0.480683604349696],[0.9138137847485798,0.5199742641834972],[0.10648484055334224,0.5333498079567062]],"notes":"surge of academic papers published in the field,","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.10757142055898859,0.5408735513291362],[0.9170735247655188,0.5216462071551483],[0.9181601047711652,0.5575929810456474],[0.10974458057028129,0.5692965818472051]],"notes":"but more importantly by the worldwide industry","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.11083116057592764,0.573476439276333],[0.10865800056463494,0.5977196123652742],[0.8258008042912255,0.6010634983085764],[0.8312337043194573,0.5651167244180774],[0.8312337043194573,0.560100895503124]],"notes":"adoption of a variety of deep learning","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.11191774058157399,0.6027354412802275],[0.11300432058722033,0.637846243684901],[0.8138484242291156,0.6370102721990754],[0.8051557841839448,0.6035714127660531]],"notes":"methods in designing and 
deploying","imageWidth":1274,"imageHeight":1655},{"label":["line"],"shape":"polygon","points":[[0.11191774058157399,0.647041930028982],[0.699757523636249,0.6436980440856799],[0.6954112036136636,0.6846606468911322],[0.11626406060415938,0.6804807894620044]],"notes":"speech recognition systems","imageWidth":1274,"imageHeight":1655}],"extras":null,"metadata":{"first_done_at":1551641986000,"last_updated_at":1551642498000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/440fa873-d9f6-4605-917f-eaba293624b1___page_0011.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551641921000,"last_updated_at":1551641921000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/90d345d6-bc45-4fc0-9e9e-9bb10d4ee06e___page_0038.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.1451990632318501,0.30685920577617326],[0.9344262295081968,0.2851985559566787],[0.9508196721311475,0.3194945848375451],[0.1522248243559719,0.351985559566787]],"notes":"In mathematical analysis, the maxima and minima (the respective","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.9812646370023419,0.3664259927797834],[0.9718969555035128,0.3285198555956679],[0.14285714285714285,0.3574007220216607],[0.1451990632318501,0.3862815884476534],[0.14754098360655737,0.388086642599278],[0.14754098360655737,0.3916967509025271],[0.14754098360655737,0.388086642599278],[0.9742388758782201,0.3628158844765343],[0.9742388758782201,0.36101083032490977],[0.9742388758782201,0.3592057761732852]],"notes":"plurals of maximum and minimum) of a ruction, known 
collectively","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.1405152224824356,0.3953068592057762],[0.8711943793911007,0.37184115523465705],[0.8711943793911007,0.36823104693140796],[0.8711943793911007,0.36823104693140796],[0.8711943793911007,0.3664259927797834],[0.8711943793911007,0.3628158844765343],[0.8594847775175644,0.3935018050541516],[0.14754098360655737,0.42057761732851984]],"notes":"as extrema (the plural of extremum), are the largest","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.14754098360655737,0.4296028880866426],[0.8875878220140515,0.3971119133574007],[0.9039812646370023,0.3953068592057762],[0.9016393442622951,0.4187725631768953],[0.9039812646370023,0.427797833935018],[0.14988290398126464,0.4602888086642599],[0.1358313817330211,0.4332129963898917]],"notes":"and smallest value of the function, either within a given","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.1451990632318501,0.4657039711191336],[0.9461358313817331,0.4332129963898917],[0.955503512880562,0.4296028880866426],[0.955503512880562,0.4657039711191336],[0.1451990632318501,0.4927797833935018],[0.1358313817330211,0.4657039711191336],[0.1358313817330211,0.5018050541516246]],"notes":"range (the local or relative extrema) or on the entire domain","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13817330210772832,0.5],[0.9765807962529274,0.4675090252707581],[0.9929742388758782,0.5],[0.1405152224824356,0.5306859205776173]],"notes":"of a function (the global or absolute extrema). 
Pierre de Fermat","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.14285714285714285,0.5379061371841155],[0.9765807962529274,0.5054151624548736],[0.9695550351288056,0.5451263537906137],[0.1451990632318501,0.5740072202166066]],"notes":"was one of the first mathematicians to propose a general","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13817330210772832,0.5812274368231047],[0.9344262295081968,0.5523465703971119],[0.9344262295081968,0.5794223826714802],[0.1358313817330211,0.6137184115523465]],"notes":"technique, adequality for finding the maxima and minima","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13348946135831383,0.6209386281588448],[0.3255269320843091,0.6155234657039711],[0.32786885245901637,0.6534296028880866],[0.1358313817330211,0.6570397111913358],[0.12177985948477751,0.6155234657039711]],"notes":"of functions.","imageWidth":1273,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642026000,"last_updated_at":1551642404000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/689c78f1-025c-4c2f-b424-ce06d35cfe3e___page_0010.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.07932692307692307,0.29944547134935307],[0.8990384615384616,0.29944547134935307],[0.8990384615384616,0.3438077634011091],[0.07932692307692307,0.3438077634011091]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.0889423076923077,0.35489833641404805],[0.9038461538461539,0.35489833641404805],[0.9038461538461539,0.3955637707948244],[0.0889423076923077,0.3955637707948244]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.07211538461538461,0.3974121996303142],[0.9423076923076923,0.3974121996303142],[0.9423076923076923,0.4343807763401109],[0.07211538461538461,0.4343807763401109]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.08653846153846154,0.43807763401109057],[0.9086538461538461,0.43807763401109057],[0.9086538461538461,0.46950092421441775],[0.08653846153846154,0.46950092421441775]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.057692307692307696,0.47504621072088726],[0.9711538461538461,0.47504621072088726],[0.9711538461538461,0.512014787430684],[0.057692307692307696,0.512014787430684]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.08653846153846154,0.5175600739371534],[0.9615384615384616,0.5175600739371534],[0.9615384615384616,0.5508317929759704],[0.08653846153846154,0.5508317929759704]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.10096153846153846,0.5600739371534196],[0.9423076923076923,0.5600739371534196],[0.9423076923076923,0.5951940850277264],[0.10096153846153846,0.5951940850277264]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"
shape":"rectangle","points":[[0.08173076923076923,0.5970425138632163],[0.9783653846153846,0.5970425138632163],[0.9783653846153846,0.6377079482439926],[0.08173076923076923,0.6377079482439926]],"notes":"","imageWidth":1273,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.11057692307692307,0.6414048059149723],[0.8100961538461539,0.6414048059149723],[0.8100961538461539,0.6968576709796673],[0.11057692307692307,0.6968576709796673]],"notes":"","imageWidth":1273,"imageHeight":1656}],"extras":null,"metadata":{"first_done_at":1551642448000,"last_updated_at":1551642448000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/a1628e99-7b9a-4f84-953b-66c24eeabe1a___page_0004.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.06325661478110305,0.36367123323056577],[0.06472769884577985,0.3251515387450853],[0.9400227173284846,0.30702462369309447],[0.9402678980059309,0.3512089791323221]],"notes":"Computer science is the study of processes","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.0465843287147658,0.4116320293056248],[0.044132521940304446,0.36442652135773207],[0.9390419946187002,0.35026486897336423],[0.9451715115548536,0.38425283469584703]],"notes":"that interact with data and that can","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.03432529484245901,0.45883753725351756],[0.04045481177861241,0.41257613946458266],[0.913298023486856,0.3927498261264677],[0.9108462167123945,0.43429067312061337]],"notes":"be represented as data in the form","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.031873488067997655,0.4890490623401689],[0.038003005004151054,0.45694931693560187],[0.9770449996228513,0.42296135121311906],[0.985626323333466,0.4607257575714333]],"notes":"of programs. 
It enables the use of algorithms","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.02696987451907494,0.5305899093343146],[0.015936744033998827,0.4890490623401689],[0.9108462167123945,0.456005206776644],[0.9255570573591627,0.501322494406621]],"notes":"to manipulate, store, and communicate digital","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.015936744033998827,0.5655219852157553],[0.01348493725953747,0.5287016890163988],[0.924331153971932,0.4862167318632954],[0.9476233183293149,0.5305899093343146]],"notes":"information. A computer scientist studies the","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.017162647421229506,0.6042305017330273],[0.014710840646768149,0.5674102055336709],[0.9047166997762411,0.5249252483805674],[0.9145239268740866,0.5674102055336709]],"notes":"theory of computation and the practice","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.017162647421229506,0.6467154588861307],[0.014710840646768149,0.6004540610971958],[0.9108462167123945,0.5692984258515866],[0.9120721200996252,0.6155598236405215]],"notes":"of designing software systems. Its fields","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.008581323710614753,0.6844798652444449],[0.012259033872306792,0.6495477893630043],[0.821355269444555,0.6089510525278166],[0.8238070762190164,0.6457713487271729]],"notes":"can be divided into theoretical and","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.008581323710614753,0.7250766020796328],[0.003677710161692037,0.6844798652444449],[0.9660118691377751,0.6278332557069737],[0.981948613171774,0.6750387636548664]],"notes":"practical disciplines. 
Computational complexity","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.018388550808460188,0.7543440170073262],[0.011033130485076112,0.7184678309669277],[0.787029974602096,0.67787109413174],[0.7992890084744028,0.7222442716027592]],"notes":"is highly abstract, while computer","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.022066260970152223,0.8006054147962611],[0.018388550808460188,0.7543440170073262],[0.8201293660573243,0.7213001614438013],[0.8274847863807084,0.7618968982789891]],"notes":"graphics emphasizes real-world","imageWidth":1276,"imageHeight":1656},{"label":["line"],"shape":"polygon","points":[[0.02942168129353629,0.82326405861125],[0.028195777906305647,0.8006054147962619],[0.3371234314884368,0.7864437624118941],[0.3371234314884368,0.8223199484522927]],"notes":"applications","imageWidth":1276,"imageHeight":1656}],"extras":null,"metadata":{"first_done_at":1551641899000,"last_updated_at":1551642413000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"CORRECT"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/4b4847ca-35ad-49ca-9c7d-11217b51dcb4___page_0009.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.16563203811314875,0.29920697287551934],[0.17160076020731627,0.3446634168316079],[0.975139972134619,0.3078379432469286],[0.9788704234434737,0.2692862755879674]],"notes":"Gradient descent is a first-order iterative optimization","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.18414220221641592,0.3944108534188876],[0.1740980820955205,0.3485791503211118],[0.31304174376790705,0.34147846392568176],[0.8227808399033493,0.3143667522340397],[0.9625615115858105,0.3104936505638051],[0.9290811111828258,0.3356688114203299],[0.7415908689261114,0.35955293838677643]],"notes":"algorithm for finding the minimum of a 
function.","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.18246818219626668,0.3944108534188876],[0.18246818219626668,0.42281359900060783],[0.9525173914649151,0.3905377517486531],[0.9726056317067059,0.3614894892218937],[0.6670969780294704,0.3569708706066201]],"notes":"To find a local minimum of a function using gradient","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.18330519220634126,0.424104632890686],[0.18330519220634126,0.46412668348311004],[0.9466583213943929,0.42668670067084224],[0.935777191263423,0.3950563703639268]],"notes":"descent, one takes steps proportional to the negative of ","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.1903054702443,0.46283086228848935],[0.1986522013953658,0.4885794498984192],[0.28879689782687634,0.4885794498984192],[0.777080670164225,0.46283086228848935],[0.9598740823725658,0.45703743007625514],[0.9523620243366065,0.4293576983955805],[0.8730680784014816,0.4287139837053323]],"notes":"the gradient (or approximate gradient) of the function","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.19809454641782176,0.4914711490862539],[0.21771711941203994,0.5202964364226911],[0.5513008603137493,0.5130901145885818],[0.9119824401122361,0.4857060916189665],[0.9101136236365962,0.46048396519958396]],"notes":"at the current point. 
If, instead, one takes steps","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.21024185350948063,0.5606518386937032],[0.21024185350948063,0.5282233904402113],[0.5372847367464505,0.51164885022176],[0.9045071742096767,0.5030012640208288],[0.9063759906853165,0.5339884479074988]],"notes":"one approaches a local maximum of that function","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.21491389469858022,0.5743438501785108],[0.5400879614599102,0.5505629881259502],[0.8989007247827572,0.5375916088245534],[0.901703949496217,0.5671375283444015],[0.21491389469858022,0.6082135627988244]],"notes":"","imageWidth":1274,"imageHeight":1650},{"label":["line"],"shape":"polygon","points":[[0.21865152764985987,0.6161405168163447],[0.21491389469858022,0.6500102294366584],[0.9568340355275918,0.6089341949822354],[0.9512275861006724,0.5736232179950999]],"notes":"the procedure is then known as gradient ascent","imageWidth":1274,"imageHeight":1650}],"extras":null,"metadata":{"first_done_at":1551642278000,"last_updated_at":1551642496000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/d24d8d1c-4385-4dbc-98a3-b72e78b7a95f___page_0021.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642589000,"last_updated_at":1551642589000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/f63c79eb-1b21-4921-8ff6-8acb5415d7f6___page_0035.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.14758836868157094,0.32632492232155297],[0.8758748982504443,0.32632492232155297],[0.8758748982504443,0.3549778423302747],[0.14758836868157094,0.3549778423302747]],"notes":"Deep learning (also known as deep structured 
learning or","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.15241602560106157,0.3571002808494393],[0.866909249685676,0.3571002808494393],[0.866909249685676,0.38256954307941415],[0.15241602560106157,0.38256954307941415]],"notes":"hierarchical learning) is part of a broader family of","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1517263603268486,0.38628381048795213],[0.9807040199308125,0.38628381048795213],[0.9807040199308125,0.4053857571604333],[0.1517263603268486,0.4053857571604333]],"notes":"machine learning methods based on learning data representations,","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1517263603268486,0.40963063419876244],[0.8648402538630371,0.40963063419876244],[0.8648402538630371,0.42926319050103473],[0.1517263603268486,0.42926319050103473]],"notes":"as opposed to task-specific algorithms. Learning can be","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.14965736450420977,0.4324468482797816],[0.8896682037347032,0.4324468482797816],[0.8896682037347032,0.45261001421184505],[0.14965736450420977,0.45261001421184505]],"notes":"supervised, semi-supervised or unsupervised. 
Deep networks","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1517263603268486,0.45685489125017414],[0.8896682037347032,0.45685489125017414],[0.8896682037347032,0.47967110533119334],[0.1517263603268486,0.47967110533119334]],"notes":"have been applied to fields including computer vision, speech","imageWidth":1273,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1503470297784227,0.4844465919993136],[0.9613933922528499,0.4844465919993136],[0.9613933922528499,0.5003648808930479],[0.1503470297784227,0.5003648808930479]],"notes":"recognition natural language processing, and audio recognition.","imageWidth":1273,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642005000,"last_updated_at":1551642503000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/7f7effd3-09e4-4836-a476-6c724cb6a2a2___page_0034.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1456223957524815,0.3119468566362452],[0.9691665270711269,0.3119468566362452],[0.9691665270711269,0.34461144895417667],[0.1456223957524815,0.34461144895417667]],"notes":"From the technology perspective, speech recognition has a long history with","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1449154909187316,0.34842231805793533],[0.9727010512398765,0.34842231805793533],[0.9727010512398765,0.37346517216834946],[0.1449154909187316,0.37346517216834946]],"notes":"several waves of major innovations. Most recently, the field has benefited","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1458882093429722,0.3736080488203237],[0.9708358486847941,0.3736080488203237],[0.9708358486847941,0.3966074878273449],[0.1458882093429722,0.3966074878273449]],"notes":"from advances in deep learning and big data. 
The advances are","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.13962952138444584,0.39850513130977233],[0.9683717983074215,0.39850513130977233],[0.9683717983074215,0.42254194875385387],[0.13962952138444584,0.42254194875385387]],"notes":"evidence not only by the surge of academic papers published","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.9889055514521929,0.40293296610210316],[0.9913696018295656,0.40293296610210316],[0.9889055514521929,0.40293296610210316],[0.9913696018295656,0.40293296610210316]],"notes":"","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.14415026823142896,0.4244376027495618],[0.9574259676022351,0.4244376027495618],[0.9574259676022351,0.4569674920113896],[0.14415026823142896,0.4569674920113896]],"notes":"in the field, but more importantly by the worldwide industry","imageWidth":1273,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1421388691398276,0.4610982716001931],[0.2713195659387469,0.4610982716001931],[0.2713195659387469,0.4912539277338767],[0.1421388691398276,0.4912539277338767]],"notes":"adoption","imageWidth":1273,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642230000,"last_updated_at":1551642517000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/2de4e581-35d9-462d-995d-cd461d0de696___page_0020.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.10386473429951697,0.2705223880597015],[0.9065562376402465,0.2705223880597015],[0.9065562376402465,0.3298105715295672],[0.10386473429951697,0.3298105715295672]],"notes":"In the simplest case, an optimization problem consists 
of","imageWidth":1278,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.11036336806055175,0.37852105594008795],[0.9341470796553845,0.37852105594008795],[0.9341470796553845,0.33285497680522474],[0.11036336806055175,0.33285497680522474]],"notes":"Maximizing or minimizing a real function by systematically","imageWidth":1278,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.13029008729370695,0.4363647561775813],[0.8682356237303328,0.4363647561775813],[0.8682356237303328,0.3873160045142097],[0.13029008729370695,0.3873160045142097]],"notes":"Choosing input values from within or allowed set onl","imageWidth":1278,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1368593353925493,0.4752654902554277],[0.9503512249658623,0.4752654902554277],[0.9503512249658623,0.44228443310247095],[0.1368593353925493,0.44228443310247095]],"notes":"computing the value of the function. The generalizations","imageWidth":1278,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1379542100756897,0.48625917597307994],[0.9634897211635471,0.48625917597307994],[0.9634897211635471,0.5437646089577224],[0.1379542100756897,0.5437646089577224]],"notes":"of optimization theory and techniques to other formulations","imageWidth":1278,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.8441483807012441,0.5953503650174753],[0.13247983665998772,0.5953503650174753],[0.13247983665998772,0.5395362682970869],[0.8441483807012441,0.5395362682970869]],"notes":"constitutes a large area of applied mathematics.","imageWidth":1278,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642039000,"last_updated_at":1551642473000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/95971e2b-2a14-43ce-a243-59b92f9b39b8___page_0036.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.08163527840800211,0.25929895472498055],[0.08673748330850224,0.30906340058128995],[0.874177772952356,0.28549076833356446],[0.8690755680518558,0.25733456870433674]],"notes":"In the simplest case, an optimization problem consists of","imageWidth":1271,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08851196004094052,0.315787138950529],[0.08851196004094052,0.3385056381556031],[0.8792188030733424,0.31502985564369324],[0.8762684044053111,0.2945832063591266],[0.8762684044053111,0.2923113564386192],[0.875284938182634,0.289282223211276],[0.8654502759558628,0.28852493990444017],[0.8654502759558628,0.28852493990444017],[0.8644668097331857,0.2862530899839328],[0.8644668097331857,0.2862530899839328],[0.09047889248629475,0.31200072241635]],"notes":"maximizing or minimizing a real function by systematically","imageWidth":1271,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08654502759558629,0.40590385246398936],[0.08654502759558629,0.430136918282735],[0.8929873301908221,0.39605916947512393],[0.8949542626361763,0.3748552368837215],[0.8949542626361763,0.3680396871221993],[0.8949542626361763,0.3657678372016919],[0.8979046613042077,0.36046685405384127]],"notes":"optimization theory and techniques to other formulations","imageWidth":1271,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08949542626361763,0.4399816012716004],[0.09047889248629475,0.46042825055616704],[0.7946407079231104,0.4278650683622276],[0.7907068430324019,0.40741841907766096],[0.7867729781416934,0.40363200254348197],[0.7867729781416934,0.39681645278195976]],"notes":"constitutes a large area of applied 
mathematics.","imageWidth":1271,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08556156137290916,0.3385056381556031],[0.08556156137290916,0.3627387039743487],[0.8654502759558628,0.33471922162142403],[0.8585660123971229,0.3066997392684994]],"notes":"choosing input values from within an allowed set and","imageWidth":1271,"imageHeight":1652},{"label":["line"],"shape":"polygon","points":[[0.08654502759558629,0.3703115370427067],[0.08949542626361763,0.39833101939563137],[0.8969211950815306,0.3703115370427067],[0.8939707964134992,0.34304933799661785],[0.8910203977454678,0.33623378823509564]],"notes":"computing the value of the function. The generalization of","imageWidth":1271,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642285000,"last_updated_at":1551642285000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"CORRECT"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/6d0d7e7d-f506-406d-b4c2-b50447d1b1ae___page_0022.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1503638731818887,0.2813181677261853],[0.9385127654976947,0.2813181677261853],[0.9385127654976947,0.3249575510853719],[0.1503638731818887,0.3249575510853719]],"notes":"Data is measured, collected and reported, and","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.07064074579014906,0.3670383850388733],[0.9163113882493621,0.3670383850388733],[0.9163113882493621,0.32573682578821456],[0.07064074579014906,0.32573682578821456]],"notes":"analyzed, whereupon it can be visualized","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.052475982586967874,0.3647005609303455],[0.9667690638137543,0.3647005609303455],[0.9667690638137543,0.404443570775319],[0.052475982586967874,0.404443570775319]],"notes":"using graphs, images or other analysis 
tools","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.045411908007952966,0.44574513002597776],[0.9990619761749653,0.44574513002597776],[0.9990619761749653,0.4028850213696338],[0.045411908007952966,0.4028850213696338]],"notes":"Data as a general concepts refers to the fact","imageWidth":1275,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642374000,"last_updated_at":1551642473000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/bce05240-8bdf-4d63-91a9-84ce6e8dcaf4___page_0037.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[1,1],[0.7181069958847737,1],[0.7181069958847737,0.6449044585987261],[1,0.6449044585987261]],"notes":"","imageWidth":1276,"imageHeight":1649}],"extras":null,"metadata":{"first_done_at":1551642147000,"last_updated_at":1551642147000,"sec_taken":15,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/171942cf-00fd-4a5a-a3c2-0c3f153c57bb___page_0033.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.5012800867937639,0.30610368414063827],[0.6506844123365307,0.30610368414063827],[0.6506844123365307,0.34397218114772754],[0.5012800867937639,0.34397218114772754]],"notes":"","imageWidth":1277,"imageHeight":1650},{"label":["line"],"shape":"rectangle","points":[[0.5004636697142952,0.3244067910273981],[0.39188019814496206,0.3244067910273981],[0.39188019814496206,0.3427098979141579],[0.5004636697142952,0.3427098979141579]],"notes":"","imageWidth":1277,"imageHeight":1650}],"extras":null,"metadata":{"first_done_at":1551641913000,"last_updated_at":1551641913000,"sec_taken":270,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/5f5bb547-376b-4572-b7e9-45f95bead94f___page_0027.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.1660684104163008,0.2924327181842311],[0.9269439378792215,0.2924327181842311],[0.9269439378792215,0.34689726455092174],[0.1660684104163008,0.34689726455092174]],"notes":"","imageWidth":1279,"imageHeight":1657},{"label":["line"],"shape":"rectangle","points":[[0.09334564245622136,0.35360059333451443],[0.9768730024488282,0.35360059333451443],[0.9768730024488282,0.3996859787217142],[0.09334564245622136,0.3996859787217142]],"notes":"","imageWidth":1279,"imageHeight":1657},{"label":["line"],"shape":"rectangle","points":[[0.09877271469204818,0.4156063845827468],[0.9812146602374897,0.4156063845827468],[0.9812146602374897,0.4767742597330301],[0.09877271469204818,0.4767742597330301]],"notes":"","imageWidth":1279,"imageHeight":1657},{"label":["line"],"shape":"rectangle","points":[[0.10202895803354428,0.4918567494961137],[0.9573355423998516,0.4918567494961137],[0.9573355423998516,0.5563762890381934],[0.10202895803354428,0.5563762890381934]],"notes":"","imageWidth":1279,"imageHeight":1657},{"label":["line"],"shape":"rectangle","points":[[0.10202895803354428,0.5697829466053786],[0.970360515765836,0.5697829466053786],[0.970360515765836,0.6317887378536111],[0.10202895803354428,0.6317887378536111]],"notes":"","imageWidth":1279,"imageHeight":1657},{"label":["line"],"shape":"rectangle","points":[[0.09443105690338673,0.6426816471269492],[0.92260228009056,0.6426816471269492],[0.92260228009056,0.698822025689538],[0.09443105690338673,0.698822025689538]],"notes":"","imageWidth":1279,"imageHeight":1657},{"label":["line"],"shape":"rectangle","points":[[0.11722476029385938,0.7239595086280106],[0.46238655449244537,0.7239595086280106],[0.46238655449244537,0.7742344745049557],[0.11722476029385938,0.7742344745049557]],"notes":"","imageWidth":1279,"imageHeight":1657}],"extras":null,"metadata":{"fir
st_done_at":1551641916000,"last_updated_at":1551642055000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/a2f9fc97-65c1-42ac-b6b9-a1c5ddb28356___page_0032.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.15132252227721377,0.30347144316314506],[0.8390869718689105,0.29407433107976894],[0.8412384769249845,0.31950181083478657],[0.5895123853643115,0.330557236815229],[0.5371590956665077,0.3327683220113175],[0.5228157286260134,0.3399543488986051],[0.47978562750453085,0.33332109331033966],[0.3894224151494174,0.33387386460936175],[0.2603321117849697,0.334979407207406],[0.15132252227721377,0.3382960350015387]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15275685898126318,0.3427182053937157],[0.15419119568531262,0.36704014255068906],[0.3815335632771456,0.35819580176633514],[0.5134925400496922,0.35487917397220237],[0.570866008211669,0.3593013443643794],[0.6784412610153755,0.35598471657024666],[0.7680873050184642,0.35156254617806965],[0.8512788338533305,0.35100977487904755],[0.8448243186851081,0.3244767525259857]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15849420579746087,0.39191485100668455],[0.687047281239672,0.38472882411939696],[0.8017942175636256,0.3769900259330873],[0.7996427125075514,0.3543264026731803],[0.15705986909341144,0.37090954164384393]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15490836403733732,0.39688979269788366],[0.15203969062923847,0.4250811289480119],[0.3449579773238854,0.42839775674214464],[0.8469758237411823,0.40407581958517125],[0.8455414870371328,0.37643725463406513]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15453302859568263,0.4319596895623016],[0.15648914288170393,0
.46512936904003505],[0.5770537143762833,0.45834466187413503],[0.6063954286666028,0.46060623092943503],[0.6484518858160607,0.4485445293011683],[0.8636244572784035,0.4379905403764349],[0.8655805715644249,0.4100978553610682]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15844525716772523,0.4704063635024017],[0.1594233143107359,0.4990529048695352],[0.8489536001332438,0.4764372143165351],[0.8440633144181906,0.44703681659763495],[0.307109942905344,0.45834466187413503]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15844525716772523,0.5050837556836685],[0.15844525716772523,0.538253435161402],[0.3882886857752279,0.5344841534025686],[0.8264582858439989,0.5073453247389685],[0.8274363429870095,0.48397577783420176]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.1613794285967572,0.545038142327302],[0.16626971431181042,0.5797155345085688],[0.859712228706361,0.551822849493202],[0.856778057277329,0.5133761755531019]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.15453302859568263,0.5865002416744688],[0.16040137145374653,0.6219314902075023],[0.9037248001418402,0.5827309599156355],[0.9047028572848509,0.5510689931414353],[0.326671085765557,0.5736846836944355]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.1613794285967572,0.629470053725169],[0.17018194288385302,0.661132020499369],[0.6797497143924015,0.6392701862981357],[0.7746212572644344,0.6309777664287023],[0.787336000123573,0.633993191835769],[0.8254802287009882,0.6257007719663357],[0.8890539429966805,0.6166544957451356],[0.8880758858536698,0.5925310924886021],[0.8861197715676485,0.5872540980262355]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.16040137145374653,0.6679167276652691],[0.16431360002578912,0.7010864071430025],[
0.3296052571945889,0.7025941198465359],[0.8734050287085101,0.6535934569817023],[0.8724269715654994,0.627208484669869],[0.379486171488132,0.6505780315746357]],"notes":"","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.1574672000247146,0.7078711143089026],[0.15257691430966133,0.7448100755454694],[0.45479657149995206,0.7289790921583693],[0.44990628578489883,0.7010864071430025],[0.33840777148168477,0.6995786944394692]],"notes":"","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642472000,"last_updated_at":1551642472000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/e63382fe-8bd9-440d-9707-15b18cf0f6ea___page_0030.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.15517241379310345,0.3521594684385382],[0.9116379310344828,0.3521594684385382],[0.9116379310344828,0.31893687707641194],[0.15517241379310345,0.31893687707641194]],"notes":"resolve the truth or falsity of conjectures by mathematical proog. When","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15517241379310345,0.3903654485049834],[0.9525862068965517,0.3903654485049834],[0.9525862068965517,0.3554817275747508],[0.15517241379310345,0.3554817275747508]],"notes":"mathematical structures are good models of real phenomena, then mathematical","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.16163793103448276,0.4318936877076412],[0.9267241379310345,0.4318936877076412],[0.9267241379310345,0.3953488372093023],[0.16163793103448276,0.3953488372093023]],"notes":"reasoning can provide insight of predictions about nature. 
Through the","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.14870689655172414,0.4435215946843854],[0.8728448275862069,0.4435215946843854],[0.8728448275862069,0.48172757475083056],[0.14870689655172414,0.48172757475083056]],"notes":"use of abstraction and logic, mathematics developed from counting,","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15086206896551724,0.5149501661129569],[0.8879310344827587,0.5149501661129569],[0.8879310344827587,0.48172757475083056],[0.15086206896551724,0.48172757475083056]],"notes":"calculation, measurement, and systematic study of the shapes","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.5581896551724138,0.5548172757475083],[0.14870689655172414,0.5548172757475083],[0.14870689655172414,0.5166112956810631],[0.5581896551724138,0.5166112956810631]],"notes":"and motions of physical objets.","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15948275862068967,0.27906976744186046],[0.8987068965517241,0.27906976744186046],[0.8987068965517241,0.3222591362126246],[0.15948275862068967,0.3222591362126246]],"notes":"Mathematicians seek and use patters to formulate new conjectures; they","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642117000,"last_updated_at":1551642383000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/d7ad6428-c7e8-4fa7-94fd-13fac1b3886a___page_0018.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.10691106390356535,0.3210322806307629],[0.9311207774857028,0.3210322806307629],[0.9311207774857028,0.35361466135149705],[0.10691106390356535,0.35361466135149705]],"notes":"mathematics seek and use patterns to formulate 
new","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12804464630310733,0.36319771450465416],[0.9062577393685946,0.36319771450465416],[0.9062577393685946,0.4082380643244925],[0.12804464630310733,0.4082380643244925]],"notes":"conjuctures; they resolve the truth or falsity of","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1355035577382398,0.4101546749551239],[0.9423091446384015,0.4101546749551239],[0.9423091446384015,0.44848688756775235],[0.1355035577382398,0.44848688756775235]],"notes":"conjuctures by mathematical proff when mathematical","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1355035577382398,0.45615333009027803],[0.9895489170609071,0.45615333009027803],[0.9895489170609071,0.49831876396416924],[0.1355035577382398,0.49831876396416924]],"notes":"structures are good models of real phenomenon, then","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.10815421580942075,0.49736045864885353],[0.9945215246843288,0.49736045864885353],[0.9945215246843288,0.5529421669371647],[0.10815421580942075,0.5529421669371647]],"notes":"mathematical reasoning can provide insight or predictions","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.14420562107922766,0.5510255563065333],[0.8950693722158959,0.5510255563065333],[0.8950693722158959,0.588399463603846],[0.14420562107922766,0.588399463603846]],"notes":"about nature through the use of abstractions","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.13674670964409522,0.5893577689191617],[0.9659290308496543,0.5893577689191617],[0.9659290308496543,0.6190652336939487],[0.13674670964409522,0.6190652336939487]],"notes":"and logic, mathematics developed from 
courtesy","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.11685627915040864,0.6238567602705273],[0.9522543598852448,0.6238567602705273],[0.9522543598852448,0.6669804994597343],[0.11685627915040864,0.6669804994597343]],"notes":"calcualtions, measurement, and the systematic","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12058573486797486,0.66793880477505],[0.8764220936280648,0.66793880477505],[0.8764220936280648,0.7120208492795727],[0.12058573486797486,0.7120208492795727]],"notes":"study of the shapes and motions of physical","imageWidth":1274,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642585000,"last_updated_at":1551642718000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/1e201b1a-ca44-4bd4-9beb-58bc5e353edc___page_0019.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.14812779202958973,0.2577482912953182],[0.9492823745310903,0.2577482912953182],[0.9492823745310903,0.30372501352637493],[0.14812779202958973,0.30372501352637493]],"notes":"Thus a neural network is either a biological neural network,","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.09393469738461788,0.2939723754773629],[0.9736692671213276,0.2939723754773629],[0.09393469738461788,0.3413423317154214],[0.9736692671213276,0.3413423317154214]],"notes":"made up of real biological neurons, or an artificial neural network,","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.0858057331878721,0.3406457147119205],[0.8815410062248755,0.3406457147119205],[0.8815410062248755,0.3810495009149704],[0.0858057331878721,0.3810495009149704]],"notes":"for solving artificial intelligence (AI) problems. 
The connections of","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.09393469738461788,0.37895964990446784],[0.9086375535473614,0.37895964990446784],[0.9086375535473614,0.42006005311101857],[0.09393469738461788,0.42006005311101857]],"notes":"the biological neuron are modeled as weights. A positive weight","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.09032182440828643,0.41657696809351424],[0.8661862960754668,0.41657696809351424],[0.8661862960754668,0.45698075429656415],[0.09032182440828643,0.45698075429656415]],"notes":"reflects an excitatory connection, while negative values mean","imageWidth":1275,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.08670895143195496,0.4562841372930633],[0.41819004701036616,0.4562841372930633],[0.41819004701036616,0.49390145548210973],[0.08670895143195496,0.49390145548210973]],"notes":"inhibitory connections.","imageWidth":1275,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642105000,"last_updated_at":1551642105000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/f7aa9b59-6baf-4f30-8d58-ef25153bf7bd___page_0042.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551641898000,"last_updated_at":1551641898000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/624156aa-4323-4af4-bb63-28f130d3a4df___page_0056.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.10052596062794851,0.3011990316098742],[0.9643045852829135,0.3011990316098742],[0.9643045852829135,0.32936458990927087],[0.10052596062794851,0.32936458990927087]],"notes":"From the technology perspective, speech recognition has 
a","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.10052596062794851,0.3379866995927596],[0.9829205039177188,0.31901805828908436],[0.9881329611354642,0.3489080385251787],[0.09605814015559525,0.37075071639001694]],"notes":"long history with several waves of major innovations. Most","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.10573841784569399,0.3730499456389473],[0.9322852052310484,0.3563805335842023],[0.934519115467225,0.38454609188359895],[0.10946160157265504,0.4017903112505765]],"notes":"recently, the field has benefitted from advances in deep","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.1139294220450083,0.40064069662611135],[0.9829205039177188,0.3799476333857383],[0.9896222346262487,0.412136842870763],[0.11095087506343947,0.4270818329888102]],"notes":"learning and big data. The advances are evidenced not","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.11169551180883168,0.43168029148667086],[0.9769634099545811,0.412136842870763],[0.9821758671723265,0.43915278654569445],[0.1131847852996161,0.4667435375328585]],"notes":"only by the surge of academic papers published in the","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.10797232808187063,0.4604206570983001],[0.9784526834453654,0.4391527865456945],[0.987388324390072,0.4719168033429518],[0.10946160157265505,0.4977831323934181]],"notes":"field, but more importantly by the worldwide industry","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.11020623831804725,0.49548390314448776],[0.9397315726849705,0.4747908399041147],[0.9397315726849705,0.5018067835790462],[0.11095087506343947,0.5219250395071866]],"notes":"adoption of a variety of deep learning methods 
in","imageWidth":1276,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.12286506298971485,0.5202006175704889],[0.9576028545743835,0.5041060128279765],[0.959092128065168,0.542043295435327],[0.12286506298971485,0.5512402124310485]],"notes":"designing and deploying speech recognition systems.","imageWidth":1276,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642454000,"last_updated_at":1551642751000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/9a824824-7b2a-4619-9ae4-110cceab66e8___page_0080.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.12686579645970855,0.2806875920128781],[0.8672055383182101,0.2806875920128781],[0.8672055383182101,0.30991551267106515],[0.12686579645970855,0.30991551267106515]],"notes":"In the simplest case, an optimization problem","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.1277125964808945,0.3249786622137598],[0.8466245634263271,0.3249786622137598],[0.8466245634263271,0.35303906879044994],[0.1277125964808945,0.35303906879044994]],"notes":"consists of maximizing or minimizing a real","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12263792377304437,0.37131096144503883],[0.9202073176901536,0.37131096144503883],[0.9202073176901536,0.4032867735905694],[0.12263792377304437,0.4032867735905694]],"notes":"function by systematically choosing input values from","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.11925480863447763,0.42612663940880563],[0.9041375207819616,0.42612663940880563],[0.9041375207819616,0.4496190728218485],[0.11925480863447763,0.4496190728218485]],"notes":"within an allowed set and computing the 
value","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12010058741911932,0.47180637104527784],[0.9464264600140458,0.47180637104527784],[0.9464264600140458,0.494646236863514],[0.12010058741911932,0.494646236863514]],"notes":"of the function. The generalization of optimization","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12010058741911932,0.5214015082505904],[0.8660774754730858,0.5214015082505904],[0.8660774754730858,0.5553350231805413],[0.12010058741911932,0.5553350231805413]],"notes":"theory and the techniques to other formulations","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.12348370255768606,0.5762171862143574],[0.876226820888786,0.5762171862143574],[0.876226820888786,0.6029724576014339],[0.12348370255768606,0.6029724576014339]],"notes":"constitutes a large area of applied mathematics.","imageWidth":1274,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642513000,"last_updated_at":1551642513000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/baab8a0c-2e0a-4a61-9e7c-69ca080baacd___page_0055.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.14320109467693395,0.28556282769363056],[0.9013245370842314,0.27128468630894903],[0.9030092558451365,0.3076290461972293],[0.13646221963331354,0.3193111618756051]],"notes":"Natural language processing (N.L.P) is a","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13477750087240845,0.3543575089107325],[0.9973535064558224,0.3387813546728981],[0.9990382252167275,0.30503302049092357],[0.13646221963331354,0.32190718758191084]],"notes":"subfield of computer science, 
information","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.14320109467693395,0.3855098173864013],[0.9990382252167275,0.36733763744226117],[0.9956687876949173,0.3387813546728981],[0.14151637591602886,0.35825154747019106]],"notes":"engineering and artificial intelligence","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13814693839421865,0.39070186879901275],[0.13309278211150333,0.4192581515683758],[0.9939840689340121,0.3984899459179299],[0.9973535064558224,0.368635650295414]],"notes":"concerned with the interactions between","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.1314080633505982,0.4296422543935987],[0.9552355374331948,0.40238398447738855],[0.9333341935414284,0.43223828009990445],[0.1314080633505982,0.4607945628692675]],"notes":"computers and human (natural)","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.13646221963331354,0.4672846271350318],[0.128038625828788,0.5023309741701593],[0.9653438499986254,0.4685826399881847],[0.9569202561940999,0.4348343058062102]],"notes":"languages, in particular how to","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.128038625828788,0.5062250127296178],[0.128038625828788,0.5425693726178981],[0.9249105997369029,0.5127150769953822],[0.9063786933669468,0.4789667428134076]],"notes":"program computers to process","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.12972334458969312,0.5477614240305095],[0.128038625828788,0.5750196939467197],[0.9788216000858663,0.563337578268344],[0.9602896937159101,0.5192051412611465],[0.958604974955005,0.514013089848535]],"notes":"and analyze large amounts 
of","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.128038625828788,0.582807771065637],[0.128038625828788,0.611364053835],[0.9788216000858663,0.5957878995971656],[0.9805063188467713,0.5646355911214969]],"notes":"natural language data. Challenges","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.11746056540276852,0.6589456729826468],[0.11746056540276852,0.6985389752219474],[0.9158253458747108,0.6674299520339255],[0.8993074538649465,0.6419771148800894]],"notes":"","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.11562524406835026,0.6179383242347997],[0.11195460139951374,0.6518754404399145],[0.9708849859072585,0.6334928358288107],[0.967214343238422,0.5995557196236958]],"notes":"in natural languages processing","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.113789922733932,0.6999530217304939],[0.113789922733932,0.7254058588843301],[0.9617083792351672,0.6999530217304939],[0.9470258085598211,0.6702580450510184]],"notes":"","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.113789922733932,0.7367182309527016],[0.11195460139951374,0.7593429750894449],[0.9066487392026195,0.7522727425467126],[0.8974721325305282,0.7112653937988656]],"notes":"","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.11746056540276852,0.7649991611236308],[0.11929588673718677,0.7961081843116526],[0.7286225697640485,0.8102486493971172],[0.7322932124328849,0.756514882072352]],"notes":"","imageWidth":1276,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642443000,"last_updated_at":1551642751000,"sec_taken":274,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/3061f91f-e31b-42d4-8323-885c7c41e90c___page_0041.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642196000,"last_updated_at":1551642748000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/9231ab59-5925-457e-b988-c44188514e8d___page_0096.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.08340495409000775,0.2661502149288549],[0.9035536693084173,0.2661502149288549],[0.9035536693084173,0.3155167870527554],[0.08340495409000775,0.3155167870527554]],"notes":"There are so many interesting problems to","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.08479503665817455,0.3198095324548337],[0.9007735041720838,0.3198095324548337],[0.9007735041720838,0.3627369864756167],[0.08479503665817455,0.3627369864756167]],"notes":"work on how do you pick the one to focus on","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.07923470638550736,0.37024929092925374],[0.8910429261949162,0.37024929092925374],[0.8910429261949162,0.4142499313005564],[0.07923470638550736,0.4142499313005564]],"notes":"right now? The is no clear-cut answer,","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.06949793359369577,0.42029385702368843],[0.9570444605298523,0.42029385702368843],[0.9570444605298523,0.4571813497943845],[0.06949793359369577,0.4571813497943845]],"notes":"yet some good rules of thumb apply. For instance,","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.07239368082676644,0.465005969473017],[0.9599402077629229,0.465005969473017],[0.9599402077629229,0.5074824762998791],[0.07239368082676644,0.5074824762998791]],"notes":"consider impact/feasibility framework. 
You always","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.0709458072102311,0.5164248987897448],[0.8310794558912786,0.5164248987897448],[0.8310794558912786,0.5510767859379745],[0.0709458072102311,0.5510767859379745]],"notes":"whant to be working on high impact and","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.06805005997716045,0.558901405616607],[0.5357132381180716,0.558901405616607],[0.5357132381180716,0.6069669264996351],[0.06805005997716045,0.6069669264996351]],"notes":"high feasibility problems.","imageWidth":1274,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642378000,"last_updated_at":1551642378000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/8d8193ed-1c92-467f-8d58-6b8a7ec1a4a7___page_0097.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.06920823696737605,0.2603155293182308],[0.9148760092262723,0.2603155293182308],[0.9148760092262723,0.29980159275414225],[0.06920823696737605,0.29980159275414225]],"notes":"","imageWidth":1274,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551641856000,"last_updated_at":1551641856000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/a1352ccc-e927-4869-b760-c54f8835190d___page_0054.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.05995203836930456,0.2680221811460259],[0.920863309352518,0.2680221811460259],[0.920863309352518,0.3031423290203327],[0.05995203836930456,0.3031423290203327]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.08872901678657075,0.31053604436229204],[0.9760191846522782,0.31053604436229204],[0.9760191846522782,0.34935304990757854],[0.08872901678657075,0.34935304990757854]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.11270983213429256,0.34935304990757854],[0.973621103117506,0.34935304990757854],[0.973621103117506,0.3900184842883549],[0.11270983213429256,0.3900184842883549]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.09352517985611511,0.38817005545286504],[0.9280575539568345,0.38817005545286504],[0.9280575539568345,0.4417744916820702],[0.09352517985611511,0.4417744916820702]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.11990407673860912,0.4454713493530499],[0.9568345323741008,0.4454713493530499],[0.9568345323741008,0.4787430683918669],[0.11990407673860912,0.4787430683918669]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.10311750599520383,0.4824399260628466],[0.9856115107913669,0.4824399260628466],[0.9856115107913669,0.5249537892791127],[0.10311750599520383,0.5249537892791127]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"rectangle","points":[[0.1223021582733813,0.5415896487985212],[0.894484412470024,0.5415896487985212],[0.894484412470024,0.5822550831792976],[0.1223021582733813,0.5822550831792976]],"notes":"","imageWidth":1277,"imageHeight":1654},{"label":["line"],"shape":"
rectangle","points":[[0.12709832134292565,0.5933456561922366],[0.6594724220623501,0.5933456561922366],[0.6594724220623501,0.6561922365988909],[0.12709832134292565,0.6561922365988909]],"notes":"","imageWidth":1277,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642491000,"last_updated_at":1551642491000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/acb4d474-0457-4fbb-b149-ef4d0d568dc4___page_0068.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11648745519713262,0.26832641770401106],[0.953405017921147,0.26832641770401106],[0.953405017921147,0.30428769017980634],[0.11648745519713262,0.30428769017980634]],"notes":"","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.10215053763440861,0.31950207468879666],[0.9336917562724014,0.31950207468879666],[0.9336917562724014,0.35408022130013833],[0.10215053763440861,0.35408022130013833]],"notes":"","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.11648745519713262,0.3651452282157676],[0.9372759856630825,0.3651452282157676],[0.9372759856630825,0.3997233748271093],[0.11648745519713262,0.3997233748271093]],"notes":"","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.10931899641577061,0.4107883817427386],[0.9121863799283154,0.4107883817427386],[0.9121863799283154,0.45228215767634855],[0.10931899641577061,0.45228215767634855]],"notes":"","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"rectangle","points":[[0.10215053763440861,0.4633471645919779],[0.989247311827957,0.4633471645919779],[0.989247311827957,0.5131396957123098],[0.10215053763440861,0.5131396957123098]],"notes":"","imageWidth":1274,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642481000,"last_updated_at":1551642481000,"sec_taken":0,"la
st_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/49ba0b43-11fb-4b06-9183-f0dc66946f96___page_0108.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.15214929862493645,0.31794688190620146],[0.9471293839402295,0.29055453515735946],[0.9534689380496019,0.3394694400660059],[0.1597567635561833,0.37175327730571245]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.05959180862810012,0.3688183830111937],[0.9522010272277274,0.34240433436052464],[0.9522010272277274,0.37175327730571245],[0.05959180862810012,0.4089286050362837]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.05832389780622564,0.41577669172349424],[0.948397294762104,0.37762306589475003],[0.948397294762104,0.4069720088399379],[0.060859719449974585,0.45099542325771963]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.05705598698435117,0.45588691374858425],[0.9509331164058529,0.41088520123262956],[0.9522010272277274,0.4402341441778174],[0.050716432874978824,0.48034436620290744]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.05325225451872776,0.4881707509882909],[0.9179674350371166,0.45099542325771963],[0.9179674350371166,0.48425755859559916],[0.05705598698435117,0.5233894825225163]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.04944852205310435,0.5312158673078997],[0.9141637025714933,0.49110564528280964],[0.92177116750274,0.5165413958353058],[0.06212763027184905,0.5546950216640499],[0.05198434369685329,0.567412896940298]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.06212763027184905,0.5615431083512604],[0.991506262705836,0.5184979920316516],[0.9940420843495849,0.5556
733197622229],[0.060859719449974585,0.593826945590967]],"notes":"","imageWidth":1277,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.05198434369685329,0.6036099265726963],[0.9953099951714593,0.5625214064494334],[0.9953099951714593,0.5899137531982753],[0.07353882766871929,0.6261107828306737],[0.050716432874978824,0.6407852543032676]],"notes":"","imageWidth":1277,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642504000,"last_updated_at":1551642504000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/b93495bd-d87c-48f2-8295-f7a90d211a79___page_0050.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642099000,"last_updated_at":1551642533000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/9df2d063-395e-4c65-8b97-8309a12f95f0___page_0078.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.11704411877612936,0.3437182440748323],[0.890022986526817,0.332448793449428],[0.8973382439503251,0.3690745079819921],[0.12435937619963744,0.3775265959510454]],"notes":"when the data source has a lower-probability value","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.07315257423508086,0.37940483772194605],[0.8717348429680468,0.37940483772194605],[0.8717348429680468,0.42260439845266273],[0.07315257423508086,0.42260439845266273]],"notes":"(i.e., when a low-probability event occurs), the event","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.10363281349969787,0.42636088199446415],[0.9070919205150025,0.42636088199446415],[0.9070919205150025,0.46016923387067715],[0.10363281349969787,0.46016923387067715]],"notes":"carries more \"information\" 
(\"surprisal\") than wehen the","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.06705652638215745,0.475195168037883],[0.4828069899515336,0.4639257174124786],[0.9314761119266961,0.4592301129852268],[0.9424489980619583,0.4911602230905391],[0.48036857081036427,0.498673190174142],[0.06827573595274213,0.5146382452267981]],"notes":"source data has a higher-probability value. The amount of","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.06827573595274213,0.5193338496540499],[0.5961934800159089,0.49961231105959236],[0.9387913693502042,0.5005514319450427],[0.9412297884913736,0.5362380255921565],[0.2974871352226621,0.5559595641866141],[0.0768102029468349,0.5578378059575148]],"notes":"information conveyed by each event defined in this way becomes","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"polygon","points":[[0.08046783165858894,0.5963417622609796],[0.9229416449326033,0.586950553406476],[0.9241608545031881,0.5418727509048586],[0.5510827259042758,0.5522030806448126],[0.0768102029468349,0.5653507730411177]],"notes":"a random variable whose expected value is the information","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.07924862208800425,0.5963417622609796],[0.9314761119266961,0.5963417622609796],[0.9314761119266961,0.6470542900752991],[0.07924862208800425,0.6470542900752991]],"notes":"entropy. 
Generally, entropy refers to disorder or uncertainty,","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.09022150822326638,0.6498716527316503],[0.954641093767805,0.6498716527316503],[0.954641093767805,0.6864973672642143],[0.09022150822326638,0.6864973672642143]],"notes":"and the definition of entropy used in information theory is","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.0890022986526817,0.6902538508060158],[0.8010206878741353,0.6902538508060158],[0.8010206878741353,0.7334534115367324],[0.0890022986526817,0.7334534115367324]],"notes":"directly analogous to the definition used in","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.11460569963495999,0.7362707741930835],[0.5096296005043965,0.7362707741930835],[0.5096296005043965,0.7747747304965483],[0.11460569963495999,0.7747747304965483]],"notes":"statistical thermodynamics.","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642165000,"last_updated_at":1551642651000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/e3301354-71bf-40a0-a5ff-5f7d98102573___page_0093.jpg","annotation":[],"extras":null,"metadata":{"first_done_at":1551641913000,"last_updated_at":1551641958000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/52197e17-80a2-40c6-8be5-0fb972141774___page_0087.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11210134568238246,0.3066743032522675],[0.932335839795026,0.3066743032522675],[0.932335839795026,0.33770682203374697],[0.11210134568238246,0.33770682203374697]],"notes":"Mathematicians seek and use patterns to formulate new conjectures; 
they resolve the truth or","imageWidth":1275,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.11368023787509207,0.3480509949609068],[0.8723379364720607,0.3480509949609068],[0.8723379364720607,0.37360718689859573],[0.11368023787509207,0.37360718689859573]],"notes":"falsity of conjectures by mathematical proof. When mathematical structures are good models of","imageWidth":1275,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.9670714680346374,0.3815174367840709],[0.11289079177873727,0.3815174367840709],[0.11289079177873727,0.4119414748051292],[0.9670714680346374,0.4119414748051292]],"notes":"real phenomena, then mathematical reasoning can provide insight or predictions about nature.","imageWidth":1275,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.8960213193627049,0.41680932088849854],[0.11131189958602766,0.41680932088849854],[0.11131189958602766,0.45575208955545315],[0.8960213193627049,0.45575208955545315]],"notes":"Through the use of abstraction and logic, mathematics developed from counting, calculation,","imageWidth":1275,"imageHeight":1656},{"label":["line"],"shape":"rectangle","points":[[0.8983896576517694,0.4612284163992436],[0.11210134568238246,0.4612284163992436],[0.11210134568238246,0.4916524544203019],[0.8983896576517694,0.4916524544203019]],"notes":"measurement, ans the systematic study of the shapes and motions of physical objects.","imageWidth":1275,"imageHeight":1656}],"extras":null,"metadata":{"first_done_at":1551642260000,"last_updated_at":1551642262000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/4587050e-e1c1-47fc-8faf-4508e4471ade___page_0092.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.09429280397022333,0.32122370936902483],[0.9156327543424317,0.30975143403441685],[0.9205955334987593,0.33460803059273425],[0.09429280397022333,0.35372848948374763],[0.09429280397022333,0.37093690248565964]],"notes":"Computer Science is the study of process that interact with data and","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.09181141439205956,0.37093690248565964],[0.9181141439205955,0.35181644359464626],[0.9181141439205955,0.37667304015296366],[0.09925558312655088,0.39579349904397704]],"notes":"that can be represented as data in the form of programs. It enables the","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.09925558312655088,0.4130019120458891],[0.9652605459057072,0.3919694072657744],[0.9627791563275434,0.4187380497131931],[0.10173697270471464,0.4359464627151052]],"notes":"use of algorithms to manipulate. Store, and communicate digital information","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.0967741935483871,0.45315487571701724],[0.9354838709677419,0.4321223709369025],[0.9330024813895782,0.4569789674952199],[0.10421836228287841,0.47992351816443596]],"notes":"A computer scientist studies the theory of computation and the","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.10421836228287841,0.4933078393881453],[0.8833746898263027,0.47418738049713194],[0.8808933002481389,0.49521988527724664],[0.10669975186104218,0.5162523900573613]],"notes":"theoritical and practical disciplines. 
Computational complexity theory","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.10669975186104218,0.5315487571701721],[0.9379652605459057,0.51434034416826],[0.9354838709677419,0.5449330783938815],[0.11166253101736973,0.5583173996175909]],"notes":"theoretical and practical disciplines. Computational complexity theory","imageWidth":1273,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.10421836228287841,0.5621414913957935],[0.9330024813895782,0.5506692160611855],[0.9429280397022333,0.5889101338432122],[0.11662531017369727,0.6061185468451242]],"notes":"is highly abstract, while computer graphics emphasizes real world applications","imageWidth":1273,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642751000,"last_updated_at":1551642751000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/c63becc8-5ce8-470e-b996-f95a8308f311___page_0079.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.05394486132843181,0.33062331179310883],[0.890090211919125,0.3145644652203007],[0.8925422510704173,0.34101433016374944],[0.05762292005537035,0.356128538702863]],"notes":"Deep learning (also known as deep structured learning or hierarchical learning)","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.07143527205309437,0.3651997725230968],[0.45467585320280335,0.3562743198341413],[0.8398471173809743,0.34809265486926544],[0.8369510928382813,0.37263764976389313],[0.41992355869048714,0.39346370603812264],[0.07240061356732537,0.392719918314043]],"notes":"is a part of a broader family of machine learning methods based 
on","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.07170017871567709,0.4009404181386089],[0.8995113329784944,0.3783404363228627],[0.9049431646993791,0.41182189086470894],[0.6637698362921016,0.4210292908637167],[0.3606736262667393,0.4260515090449936],[0.12601849592452338,0.4360959454075475],[0.07170017871567709,0.4302366908627244]],"notes":"learning data representations, as opposed to task-specific algorithms. Learning","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.07278654505985402,0.43944409086173214],[0.9386205213688638,0.41182189086470894],[0.9397068877130407,0.44279223631591674],[0.7941337975933327,0.4553477817691091],[0.07278654505985402,0.46622925449520913]],"notes":"can be supervised, semi-supervised or unsupervised. Deep learning architectures","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.07387291140403095,0.48296998176613226],[0.5497013701535244,0.46622925449520913],[0.9103749964202636,0.45953296358683987],[0.9027704320110252,0.4804588726754938],[0.5171103798282166,0.49887367267350924],[0.2118414371145005,0.5147773635808862],[0.07495927774820788,0.5147773635808862]],"notes":"such as deep neural networks, deep belief networks and recurrent neural","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.07713201043656172,0.5290069817611709],[0.46496479530772417,0.508081072672517],[0.7300381832868941,0.4997107090370554],[0.9386205213688638,0.4938514544922323],[0.9386205213688638,0.5206366181257093],[0.6040196873623707,0.5382143817601786],[0.3150462398113084,0.5490958544862786],[0.07821837678073865,0.5549551090311017]],"notes":"networks have been applied to fields including computer vision, speech 
recognition,","imageWidth":1274,"imageHeight":1651},{"label":["line"],"shape":"polygon","points":[[0.08147747581326942,0.567510654484294],[0.6463879747852708,0.545747709032094],[0.6507334401619784,0.577555090846848],[0.3595872599225624,0.5942958181177711],[0.0803911094690925,0.5942958181177711]],"notes":"natural language processing and audio recognition.","imageWidth":1274,"imageHeight":1651}],"extras":null,"metadata":{"first_done_at":1551642333000,"last_updated_at":1551642333000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/60fd45a9-95ba-4a0d-8194-310423ab0aa2___page_0053.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.09855532502582025,0.3529281588444158],[0.5081758946643856,0.3529281588444158],[0.5081758946643856,0.31133729837453517],[0.09855532502582025,0.31133729837453517]],"notes":"","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.512795675524971,0.35411646914355527],[0.49739640598968654,0.35411646914355527],[0.49739640598968654,0.3731294339297864],[0.512795675524971,0.3731294339297864]],"notes":"","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.24484838561102218,0.42066184589536426],[0.0970153980722918,0.42066184589536426],[0.0970153980722918,0.337480124955603],[0.24484838561102218,0.337480124955603]],"notes":"","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.2494681664716075,0.3362918146564635],[0.9655341998623327,0.3362918146564635],[0.9655341998623327,0.41709691499794593],[0.2494681664716075,0.41709691499794593]],"notes":"","imageWidth":1276,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.10779488674699089,0.4218501561945037],[0.9778536154905603,0.4218501561945037],[0.9778536154905603,0.6250512173473491],[0.10779488674699089,0.625051217347349
1]],"notes":"","imageWidth":1276,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642397000,"last_updated_at":1551642397000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/742cee36-5414-4f6c-aa62-82c1d1065c87___page_0084.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.14214461210211474,0.2920787712271451],[0.8656116723183953,0.2920787712271451],[0.8656116723183953,0.32151306600197366],[0.14214461210211474,0.32151306600197366]],"notes":"Information is the resolution of uncertainty; it is that","imageWidth":1273,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642328000,"last_updated_at":1551642328000,"sec_taken":271,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"INCORRECT"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/75d834a1-3405-46d9-a2a3-0023693a18ee___page_0090.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.11528570550666367,0.5809108384394237],[0.932487918876908,0.5809108384394237],[0.932487918876908,0.6296476070229748],[0.11528570550666367,0.6296476070229748]],"notes":"while negative values mean inhibitory connections.","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.09692161082418628,0.5313879929432347],[0.8794360897941954,0.5313879929432347],[0.8794360897941954,0.57855260770151],[0.09692161082418628,0.57855260770151]],"notes":"A positive weight reflects an excitatory 
connection,","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.15507457731869803,0.4952284549618903],[0.2938255149196384,0.4952284549618903],[0.2938255149196384,0.5306019160305967],[0.15507457731869803,0.5306019160305967]],"notes":"weights.","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.09488115585946656,0.45278030167944255],[0.939629511253427,0.45278030167944255],[0.939629511253427,0.4952284549618903],[0.09488115585946656,0.4952284549618903]],"notes":"the connection of the biological neuron are modeled as","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.12140707040082281,0.4024713792706156],[0.5386801106860037,0.4024713792706156],[0.5386801106860037,0.4433473787277875],[0.12140707040082281,0.4433473787277875]],"notes":"intelligence (AI) problems.","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.11834638795374323,0.3655257643766333],[0.9212654165709495,0.3655257643766333],[0.9212654165709495,0.40718784074644315],[0.11834638795374323,0.40718784074644315]],"notes":"an artificial neural network, for solving artificial","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.12344752536554252,0.328580149482651],[0.903921549370832,0.328580149482651],[0.903921549370832,0.37024222585246086],[0.12344752536554252,0.37024222585246086]],"notes":"network, made up of real biological neurons, or","imageWidth":1274,"imageHeight":1653},{"label":["line"],"shape":"rectangle","points":[[0.07141592376518989,0.289276303850755],[0.9080024593002715,0.289276303850755],[0.9080024593002715,0.34037130317221986],[0.07141592376518989,0.34037130317221986]],"notes":"Thus a neural network is either a biological 
neural","imageWidth":1274,"imageHeight":1653}],"extras":null,"metadata":{"first_done_at":1551642379000,"last_updated_at":1551642488000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/e92f801d-0e36-4516-96ab-cf06f0ebcc0c___page_0091.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.16170212765957448,0.3567921440261866],[0.4808510638297872,0.3436988543371522],[0.7148936170212766,0.33878887070376434],[0.9680851063829787,0.3404255319148936],[0.9744680851063829,0.3027823240589198],[0.1702127659574468,0.3158756137479542],[0.15106382978723404,0.3142389525368249]],"notes":"","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.16808510638297872,0.39279869067103107],[0.16808510638297872,0.36006546644844517],[0.7,0.34206219312602293],[0.9148936170212766,0.3436988543371522],[0.9148936170212766,0.36824877250409166]],"notes":"","imageWidth":1274,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.16595744680851063,0.3960720130932897],[0.9744680851063829,0.36824877250409166],[0.9723404255319149,0.40589198036006546],[0.16170212765957448,0.4320785597381342],[0.15531914893617021,0.4369885433715221],[0.16382978723404254,0.4386252045826514],[0.16382978723404254,0.4386252045826514]],"notes":"","imageWidth":1274,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551597003000,"last_updated_at":1551597003000,"sec_taken":83,"last_updated_by":"69FI7aSdl6aSMhn3Anp3BRvA8gg2","status":"done","evaluation":"NONE"}} {"content": 
"http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/a0153bd0-62f9-480c-92a7-e9db47b4053d___page_0085.jpg","annotation":[{"label":["line"],"shape":"rectangle","points":[[0.9046920211971423,0.46804044558785357],[0.9976298853329644,0.46804044558785357],[0.9976298853329644,0.5272436523089902],[0.9046920211971423,0.5272436523089902]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.9295454545454546,0.5052631578947369],[0.9431818181818182,0.5052631578947369],[0.9431818181818182,0.49298245614035085],[0.9295454545454546,0.49298245614035085]],"notes":"","imageWidth":1274,"imageHeight":1652},{"label":["line"],"shape":"rectangle","points":[[0.9659090909090909,0.5157894736842106],[0.9772727272727273,0.5157894736842106],[0.9772727272727273,0.5],[0.9659090909090909,0.5]],"notes":"","imageWidth":1274,"imageHeight":1652}],"extras":null,"metadata":{"first_done_at":1551642170000,"last_updated_at":1551642483000,"sec_taken":142,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/f1614407-b739-4cf7-8495-9df7e67da4c4___page_0052.jpg","annotation":[{"label":["line"],"shape":"polygon","points":[[0.07493391582194713,0.3422201463385958],[0.8308297916758387,0.31550675938811473],[0.8270830958847413,0.2887933724376336],[0.06931387213530109,0.3111748588015502]],"notes":"There are many solutions around is necessary to","imageWidth":1276,"imageHeight":1654},{"label":["line"],"shape":"polygon","points":[[0.08149063345636749,0.3443860966318781],[0.87860016301233,0.31550675938811473],[0.8832835327512016,0.34583006349406625],[0.0889840250385622,0.36460163270251245]],"notes":"explore. 
using different techniques in order to get better solutions","imageWidth":1276,"imageHeight":1654}],"extras":null,"metadata":{"first_done_at":1551642334000,"last_updated_at":1551642478000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} {"content": "http://com.dataturks.a96-i23.open.s3.amazonaws.com/2c9fafb068c19d77016931d84bde07ba/633887b3-356f-44b3-b602-abb6de17bc99___page_0046.jpg","annotation":null,"extras":null,"metadata":{"first_done_at":1551642488000,"last_updated_at":1551642488000,"sec_taken":0,"last_updated_by":"qYxNqy3ztcMtWhTynXyHGxbAArx2","status":"done","evaluation":"NONE"}} ================================================ FILE: data/raw/fsdl_handwriting/manifest.csv ================================================ form https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-001.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-002.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-003.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-004.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-005.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-006.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-007.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-008.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-009.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-010.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-011.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-012.jpg 
https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-013.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-014.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-015.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-016.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-017.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-018.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-019.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-020.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-021.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-022.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-023.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-024.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-025.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-026.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-027.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-028.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-029.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-030.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-031.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-032.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-033.jpg 
https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-034.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-035.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-036.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-037.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-038.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-039.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-040.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-041.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-042.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-043.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-044.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-045.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-046.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-047.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-048.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-049.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-050.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-051.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-052.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-053.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-054.jpg 
https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-055.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-056.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-057.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-058.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-059.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-060.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-061.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-062.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-063.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-064.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-065.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-066.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-067.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-068.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-069.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-070.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-071.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-072.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-073.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-074.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-075.jpg 
https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-076.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-077.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-078.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-079.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-080.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-081.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-082.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-083.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-084.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-085.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-086.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-087.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-088.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-089.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-090.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-091.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-092.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-093.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-094.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-095.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-096.jpg 
https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-097.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-098.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-099.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-100.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-101.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-102.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-103.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-104.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-105.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-106.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-107.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-108.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-109.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-110.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-111.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-112.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-113.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-114.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-115.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-116.jpg https://fsdl-public-assets.s3.us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-117.jpg 
================================================ FILE: data/raw/fsdl_handwriting/metadata.toml ================================================ url = "https://dataturks.com/projects/sergeykarayev/fsdl_handwriting/export" filename = "fsdl_handwriting.json" sha256 = "720d6c72b4317a9a5492630a1c9f6d83a20d36101a29311a5cf7825c1d60c180" ================================================ FILE: data/raw/fsdl_handwriting/readme.md ================================================ # FSDL Handwriting Dataset ## Collection Handwritten paragraphs were collected in the FSDL March 2019 class. The resulting PDF was stored at https://fsdl-public-assets.s3-us-west-2.amazonaws.com/fsdl_handwriting_20190302.pdf Pages were extracted from the PDF by running `gs -q -dBATCH -dNOPAUSE -sDEVICE=jpeg -r300 -sOutputFile=page-%03d.jpg -f fsdl_handwriting_20190302.pdf` and uploaded to S3, with urls like https://fsdl-public-assets.s3-us-west-2.amazonaws.com/fsdl_handwriting_20190302/page-001.jpg ================================================ FILE: data/raw/iam/metadata.toml ================================================ url = 'https://s3-us-west-2.amazonaws.com/fsdl-public-assets/iam/iamdb.zip' filename = 'iamdb.zip' sha256 = 'f3c9e87a88a313e557c6d3548ed8a2a1af2dc3c4a678c5f3fc6f972ba4a50c55' ================================================ FILE: data/raw/iam/readme.md ================================================ # IAM Dataset The IAM Handwriting Database contains forms of handwritten English text which can be used to train and test handwritten text recognizers and to perform writer identification and verification experiments. - 657 writers contributed samples of their handwriting - 1,539 pages of scanned text - 13,353 isolated and labeled text lines - http://www.fki.inf.unibe.ch/databases/iam-handwriting-database ## Pre-processing First, all forms were placed into one directory called `forms`, from original directories like `formsA-D`. 
To save space, I converted the original PNG files to JPG and resized them to half-size:

```
mkdir forms-resized
cd forms
ls -1 *.png | parallel --eta -j 6 convert '{}' -adaptive-resize 50% '../forms-resized/{.}.jpg'
```

## Split

The data split we will use is loosely based on the IAM Lines Large Writer Independent Text Line Recognition Task (`lwitlrt`), which provides 4 data splits:

- Train: has 6,161 text lines from 747 pages written by 283 writers
- Validation 1: has 900 text lines from 105 pages written by 46 writers
- Validation 2: has 940 text lines from 115 pages written by 43 writers
- Test: has 1,861 text lines from 232 pages written by 128 writers

Total: has 9,862 text lines from 1,199 pages written by 500 writers

The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.

The total number of text lines in the data splits (9,862) is far smaller than the number of text lines in the dataset (13,353). This is because:

- pages of 157 writers (`657-500`) are not included in the data splits
- 511 text lines are dropped from the 1,199 pages included in the data splits

To avoid missing out on all the dropped data, we slightly modify the data splits. We:

- use all text lines in a page and never drop text lines
- merge Validation 1 and Validation 2 into a single Validation data split
- add all the missing pages of the 157 writers to the train data split
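
The totals quoted for the `lwitlrt` splits can be double-checked with a few lines of arithmetic. The sketch below is only a sanity check on the numbers in this readme. The dictionary layout and the names `lwitlrt_splits` and `totals` are illustrative assumptions, not part of the codebase.

```python
# Sanity check (illustrative, not part of the codebase): the per-split
# counts quoted in this readme should add up to the quoted totals.
lwitlrt_splits = {
    "train": {"lines": 6161, "pages": 747, "writers": 283},
    "validation_1": {"lines": 900, "pages": 105, "writers": 46},
    "validation_2": {"lines": 940, "pages": 115, "writers": 43},
    "test": {"lines": 1861, "pages": 232, "writers": 128},
}

# Sum each quantity over the four splits.
totals = {
    key: sum(split[key] for split in lwitlrt_splits.values())
    for key in ("lines", "pages", "writers")
}
print(totals)

# Writers and text lines in the full dataset that these splits leave out,
# using the dataset-wide figures quoted in this readme (657 writers, 13,353 lines).
missing_writers = 657 - totals["writers"]
missing_lines = 13353 - totals["lines"]
print(missing_writers, missing_lines)
```

The same pattern can be reused to check any modified split against the dataset-wide figures of 13,353 text lines, 1,539 pages, and 657 writers.
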
Our final data splits are: - Train: has 9,462 text lines from 1,087 pages written by 440 writers - Validation: has 1,926 text lines from 220 pages written by 89 writers - Test: has 1,965 text lines from 232 pages written by 128 writers Total: has 13,353 text lines from 1,539 pages written by 657 writers ================================================ FILE: environment.yml ================================================ name: fsdl-text-recognizer-2022 channels: - pytorch - nvidia - defaults dependencies: - python=3.10 # versioned to match Google Colab # version also pinned in Dockerfile - pytorch=2.1.1 # versioned to match Google Colab - pytorch-cuda=12.1 # versioned to match Google Colab - pip=23.1.2 # versioned to match Google Colab # version also pinned in Dockerfile ================================================ FILE: lab01/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", 
"showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For refreshing on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": "markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the `requests` library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing\n", "and matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models." 
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch." 
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how well our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5 # learning rate hyperparameter\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs): # loop over the data repeatedly\n", " for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n", " start_idx = ii * bs # we are ii batches in, each of size bs\n", " end_idx = start_idx + bs # and we want the next bs entries\n", "\n", " # pull batches from x and from y\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", "\n", " # run model\n", " pred = model(xb)\n", "\n", " # get loss\n", " loss = loss_func(pred, yb)\n", "\n", " # calculate the gradients with a backwards pass\n", " loss.backward()\n", "\n", " # update the parameters\n", " with torch.no_grad(): # we don't want to track gradients through this part!\n", " # SGD learning rule: update with negative gradient scaled by lr\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", "\n", " # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", " # until you say so -- by explicitly \"deleting\" them,\n", " # i.e. 
setting the gradients to 0.\n", " weights.grad.zero_()\n", " bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset was, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and 
`torch.utils.data`." 
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
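, "\n",
"\n",
"As a minimal sketch of what that tracking means in practice\n",
"(the `Tracked` class and its attribute names are invented for illustration\n",
"and are not part of the `text_recognizer` codebase):\n",
"a `Tensor` wrapped in `nn.Parameter` is reported by `.named_parameters`,\n",
"while a plain `Tensor` attribute is not.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class Tracked(nn.Module):  # hypothetical example module\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.scale = nn.Parameter(torch.ones(3))  # registered as model state\n",
"        self.offset = torch.zeros(3)  # plain Tensor: not registered\n",
"\n",
"\n",
"tracked = Tracked()\n",
"param_names = [name for name, _ in tracked.named_parameters()]\n",
"print(param_names)  # lists only the nn.Parameter\n",
"```"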
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()  # the nn.Module.__init__ method does important setup, so this is mandatory\n", "        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", "        self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once."
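, "\n",
"\n",
"To make the `__call__` remark above concrete,\n",
"here is a plain-Python sketch with no PyTorch involved.\n",
"The `Doubler` class is a made-up example;\n",
"`nn.Module` defines its own `__call__`, which runs any registered hooks and then calls `.forward`.\n",
"\n",
"```python\n",
"class Doubler:  # hypothetical class, unrelated to PyTorch\n",
"    def __call__(self, x):\n",
"        # runs whenever an instance is used like a function\n",
"        return 2 * x\n",
"\n",
"\n",
"doubler = Doubler()\n",
"result = doubler(21)  # equivalent to doubler.__call__(21)\n",
"print(result)\n",
"```"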
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): # finds params automatically\n", " p -= p.grad * lr\n", " model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
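, "\n",
"\n",
"One detail worth checking by hand, as a small sketch with made-up inputs:\n",
"`nn.Linear` stores its weight with shape `(out_features, in_features)`,\n",
"the transpose of the `weights` matrix we defined ourselves,\n",
"so the equivalent manual computation is `x @ weight.T + bias`.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"lin = nn.Linear(784, 10)\n",
"x = torch.randn(4, 784)  # a fake batch of four flattened images\n",
"\n",
"# same matrix multiplication and bias operation, written out by hand\n",
"by_hand = x @ lin.weight.T + lin.bias\n",
"matches = torch.allclose(lin(x), by_hand, atol=1e-4)\n",
"print(lin.weight.shape, matches)\n",
"```"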
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", " def forward(self, xb):\n", " return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
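, "\n",
"\n",
"Before switching optimizers, here is a small sanity check, sketched on a toy layer\n",
"rather than our MNIST model:\n",
"with plain `optim.SGD` (used here instead of `Adam` only so the result can be compared by hand),\n",
"one call to `.step` reproduces the `p -= p.grad * lr` rule from our earlier loops.\n",
"The layer sizes and learning rate are arbitrary choices for illustration.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn, optim\n",
"\n",
"torch.manual_seed(0)\n",
"layer = nn.Linear(3, 1)\n",
"opt = optim.SGD(layer.parameters(), lr=0.1)\n",
"\n",
"loss = layer(torch.randn(8, 3)).pow(2).mean()\n",
"loss.backward()\n",
"\n",
"# what the by-hand update rule would produce\n",
"expected = [p.detach() - 0.1 * p.grad for p in layer.parameters()]\n",
"\n",
"opt.step()  # updates every parameter the optimizer was pointed at\n",
"same_update = all(torch.allclose(p, e) for p, e in zip(layer.parameters(), expected))\n",
"\n",
"opt.zero_grad()  # clears the gradients, ready for the next batch\n",
"# depending on the PyTorch version, cleared gradients are either None or all zeros\n",
"grads_cleared = all(p.grad is None or not p.grad.any() for p in layer.parameters())\n",
"print(same_update, grads_cleared)\n",
"```"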
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
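, "\n",
"\n",
"For reference, a `Dataset` written from scratch with the two-method recipe above\n",
"might look like the sketch below.\n",
"`PairedDataset` is a hypothetical stand-in,\n",
"not the actual implementation of `BaseDataset`, which also supports transforms.\n",
"\n",
"```python\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class PairedDataset(Dataset):  # hypothetical; not the labs' BaseDataset\n",
"    def __init__(self, data, targets):\n",
"        assert len(data) == len(targets), 'inputs and targets must align'\n",
"        self.data = data\n",
"        self.targets = targets\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, index):\n",
"        # one index (or slice) pulls from both sources at once\n",
"        return self.data[index], self.targets[index]\n",
"\n",
"\n",
"xs, ys = torch.arange(10.0).reshape(5, 2), torch.arange(5)\n",
"ds = PairedDataset(xs, ys)\n",
"x0, y0 = ds[0]\n",
"print(len(ds), x0, y0)\n",
"```"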
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", "    opt = configure_optimizer(self)\n", "\n", "    for epoch in range(epochs):\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten-line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!"
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", " weights.grad.zero_()\n", " bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move electrons around."
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
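, "\n",
"\n",
"A small sketch of the device bookkeeping involved, using a toy layer rather than our model\n",
"(the names below are illustrative only).\n",
"It falls back to the CPU when no GPU is available.\n",
"Note that `nn.Module.to` moves the module's parameters in place,\n",
"while `Tensor.to` returns a new `Tensor` and leaves the original where it was.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"layer = nn.Linear(2, 2).to(device)  # moves every registered parameter\n",
"batch = torch.ones(1, 2)  # Tensors are created on the CPU by default\n",
"\n",
"moved = batch.to(device)  # returns a Tensor on the target device\n",
"out = layer(moved)  # inputs and parameters must share a device\n",
"print(batch.device, moved.device, out.device)\n",
"```"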
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", "    return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", "    url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", "    filename = \"mnist.pkl.gz\"\n", "\n", "    def __init__(self, dir, bs=32):\n", "        self.dir = dir\n", "        self.bs = bs\n", "        self.path = self.dir / self.filename\n", "\n", "    def prepare_data(self):\n", "        if not self.path.exists():\n", "            content = requests.get(self.url + self.filename).content\n", "            self.path.write_bytes(content)  # writes the download and closes the file\n", "\n", "    def setup(self):\n", "        with gzip.open(self.path, \"rb\") as f:\n", "            ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", "        x_train, y_train, x_valid, y_valid = map(\n", "            torch.tensor, (x_train, y_train, x_valid, y_valid)\n", "        )\n", "\n", "        self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "        self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", "    def train_dataloader(self):\n", "        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", "\n", "    def val_dataloader(self):\n", "        return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" },
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", "    datamodule.prepare_data()\n", "    datamodule.setup()\n", "\n", "    val_dataloader = datamodule.val_dataloader()\n", "\n", "    self.eval()\n", "    with torch.no_grad():\n", "        valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", "    print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", "    opt = configure_optimizer(self)\n", "    train_dataloader = datamodule.train_dataloader()\n", "    for epoch in range(epochs):\n", "        self.train()\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "        self.eval()\n", "        with torch.no_grad():\n", "            valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", "        print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
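, "\n",
"\n",
"Both points in the list above can be seen in isolation.\n",
"The sketch below uses a standalone `nn.Dropout` layer and a toy `Tensor`,\n",
"chosen only for illustration, rather than the `MLP`.\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"drop = nn.Dropout(p=0.5)\n",
"x = torch.ones(1000)\n",
"\n",
"drop.train()  # training mode: randomly zeroes entries and rescales the rest\n",
"n_zeroed_train = int((drop(x) == 0).sum())\n",
"\n",
"drop.eval()  # evaluation mode: dropout passes inputs through unchanged\n",
"n_zeroed_eval = int((drop(x) == 0).sum())\n",
"\n",
"w = torch.ones(3, requires_grad=True)\n",
"with torch.no_grad():\n",
"    tracked_inside = (w * 2).requires_grad  # no graph is recorded here\n",
"tracked_outside = (w * 2).requires_grad\n",
"\n",
"print(n_zeroed_train, n_zeroed_eval, tracked_inside, tracked_outside)\n",
"```"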
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab01/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab01/text_recognizer/data/util.py ================================================ 
"""Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. Parameters ---------- index Returns ------- (datum, target) """ datum, target = self.data[index], self.targets[index] if self.transform is not None: datum = self.transform(datum) if self.target_transform is not None: target = self.target_transform(target) return datum, target def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with and tokens, and padded with the
<P> token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]: """ Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the other of size (1 - fraction) * size of the base_dataset. """ split_a_size = int(fraction * len(base_dataset)) split_b_size = len(base_dataset) - split_a_size return torch.utils.data.random_split( # type: ignore base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed) ) def resize_image(image: Image.Image, scale_factor: int) -> Image.Image: """Resize image by scale factor.""" if scale_factor == 1: return image return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR) ================================================ FILE: lab01/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab01/text_recognizer/metadata/shared.py ================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab01/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP ================================================ FILE: lab01/text_recognizer/models/mlp.py ================================================ import argparse
from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab01/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, 
grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab02/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing 
performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For refreshing on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": "markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the Python standard library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing\n", "and matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models."
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randrange(len(x_train)) # randint(0, len(x_train)) could return the out-of-range index len(x_train)\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch."
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how well our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5  # learning rate hyperparameter\n", "epochs = 2  # how many epochs to train for\n", "\n", "for epoch in range(epochs):  # loop over the data repeatedly\n", "    for ii in range((n - 1) // bs + 1):  # in batches of size bs, so roughly n / bs of them\n", "        start_idx = ii * bs  # we are ii batches in, each of size bs\n", "        end_idx = start_idx + bs  # and we want the next bs entries\n", "\n", "        # pull batches from x and from y\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "\n", "        # run model\n", "        pred = model(xb)\n", "\n", "        # get loss\n", "        loss = loss_func(pred, yb)\n", "\n", "        # calculate the gradients with a backwards pass\n", "        loss.backward()\n", "\n", "        # update the parameters\n", "        with torch.no_grad():  # we don't want to track gradients through this part!\n", "            # SGD learning rule: update with negative gradient scaled by lr\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "\n", "            # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", "            # until you say so -- by explicitly \"deleting\" them,\n", "            # i.e. 
setting the gradients to 0.\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1)  # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset was, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`." 
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()  # the nn.Module.__init__ method does important setup, so this is mandatory\n", "        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", "        self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", "    for epoch in range(epochs):\n", "        for ii in range((n - 1) // bs + 1):\n", "            start_idx = ii * bs\n", "            end_idx = start_idx + bs\n", "            xb = x_train[start_idx:end_idx]\n", "            yb = y_train[start_idx:end_idx]\n", "            pred = model(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            with torch.no_grad():\n", "                for p in model.parameters():  # finds params automatically\n", "                    p -= p.grad * lr\n", "                model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.lin = nn.Linear(784, 10)  # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", "    def forward(self, xb):\n", "        return self.lin(xb)  # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb))  # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", "    return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", "    for ii in range((n - 1) // bs + 1):\n", "        start_idx = ii * bs\n", "        end_idx = start_idx + bs\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", "    for ii in range((n - 1) // bs + 1):\n", "        xb, yb = train_ds[ii * bs: ii * bs + bs]  # xb and yb in one line!\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", "    opt = configure_optimizer(self)\n", "\n", "    for epoch in range(epochs):\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten-line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!" 
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5  # learning rate\n", "epochs = 2  # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", "    for ii in range((n - 1) // bs + 1):\n", "        start_idx = ii * bs\n", "        end_idx = start_idx + bs\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "\n", "        loss.backward()\n", "        with torch.no_grad():\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit  # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move around electrons." 
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", "    return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", "    url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", "    filename = \"mnist.pkl.gz\"\n", "\n", "    def __init__(self, dir, bs=32):\n", "        self.dir = dir\n", "        self.bs = bs\n", "        self.path = self.dir / self.filename\n", "\n", "    def prepare_data(self):\n", "        if not (self.path).exists():\n", "            content = requests.get(self.url + self.filename).content\n", "            self.path.open(\"wb\").write(content)\n", "\n", "    def setup(self):\n", "        with gzip.open(self.path, \"rb\") as f:\n", "            ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", "        x_train, y_train, x_valid, y_valid = map(\n", "            torch.tensor, (x_train, y_train, x_valid, y_valid)\n", "        )\n", "\n", "        self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "        self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", "    def train_dataloader(self):\n", "        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", "\n", "    def val_dataloader(self):\n", "        return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", "    datamodule.prepare_data()\n", "    datamodule.setup()\n", "\n", "    val_dataloader = datamodule.val_dataloader()\n", "\n", "    self.eval()\n", "    with torch.no_grad():\n", "        valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", "    print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", "    opt = configure_optimizer(self)\n", "    train_dataloader = datamodule.train_dataloader()\n", "    for epoch in range(epochs):\n", "        self.train()\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "        self.eval()\n", "        with torch.no_grad():\n", "            valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", "        print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
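, "\n", "\n", "To see the first point in action, here's a minimal sketch\n", "(not part of the lab's code) that applies a standalone `nn.Dropout` layer\n", "to the same input in both modes:\n", "\n", "```python\n", "import torch\n", "from torch import nn\n", "\n", "drop = nn.Dropout(p=0.5)\n", "xs = torch.ones(8)\n", "\n", "drop.train() # training mode: entries are randomly zeroed, survivors are scaled by 1 / (1 - p)\n", "print(drop(xs))\n", "\n", "drop.eval() # evaluation mode: dropout does nothing\n", "print(drop(xs))\n", "```\n", "\n", "In `train` mode the output changes from call to call,\n", "while in `eval` mode the input passes through unchanged."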
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab02/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", " inputs, targets = batch\n", "\n", " outputs = model(inputs)\n", " loss = some_loss_function(outputs, targets)\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", "\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes." 
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", " def __init__(self):\n", " super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", " # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", " self.model = torch.nn.Linear(in_features=1, out_features=1)\n", " # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", " # optionally, define a forward method\n", " def forward(self, xs):\n", " return self.model(xs) # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try: # try using the LinearRegression LightningModule defined above\n", " logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n", "\n", " model = LinearRegression()\n", "\n", " # we'll explain how the Trainer works in a bit\n", " trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", " trainer.fit(model=model) \n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", " print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n", "\n", "finally: # bring back info-level logging\n", " logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message 
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", " # run the train step\n", " loss = training_step(batch)\n", "\n", " # clear gradients\n", " optimizer.zero_grad()\n", "\n", " # backprop\n", " loss.backward()\n", "\n", " # update parameters\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " xs, ys = batch # unpack the batch\n", " outs = self(xs) # apply the model\n", " loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n", " return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we 
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudocode (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", " def closure(): # wrap the loss calculation\n", " loss = model.training_step(batch, batch_idx, ...)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " return loss\n", "\n", " # optimizer can call the loss calculation as many times as it likes\n", " optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", " optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n", " return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n", 
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", " def __init__(self, N=10_000):\n", " self.N = N\n", " self.xs = torch.randn(size=(N, 1))\n", " self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n", "\n", " def __getitem__(self, idx):\n", " return (self.xs[idx], self.ys[idx])\n", "\n", " def __len__(self):\n", " return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", " def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", " super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", " # set some constants, like the train/val split\n", " self.size = size\n", " self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", " self.train_indices = list(range(math.floor(self.size * train_frac)))\n", " # start validation right after the last training index so the two splits don't overlap\n", " self.val_indices = list(range(len(self.train_indices), self.size))\n", "\n", " # under the hood, we've still got a torch Dataset\n", " self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up once per machine\n", "(for example, downloading data from storage and writing it to the local disk)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates best performance practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", " def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n", " super().__init__() # don't forget this!\n", " self.dataset = None\n", " self.test_dataset = None # define a test set -- another sample from the same distribution\n", "\n", " def setup(self, stage=None):\n", " pass\n", "\n", " def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " pass # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab02/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?" 
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we repeatedly apply that kernel over and over again,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type": 
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh" 
}, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as feature maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represents (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "arguments.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of 
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work off of\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors." 
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like DropOut\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by the layer can be higher\n", "than that done by layers with more parameters --\n", "much higher." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2D layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n", "    num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", "    input_height, input_width = input_size[1], input_size[2]\n", "\n", "    multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", "    num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n", "    multiplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", "    return multiplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n", "filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", "    \"model_class\": \"CNN\",\n", "    \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", "    latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", " --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels." 
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `<P>`.\n", "\n", "A model that always predicts `<P>` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"<P>\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab02/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab02/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. 
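As a rough illustration of that contract, here is a minimal sketch. `ToyDataModule` and `ToyModel` are hypothetical stand-ins, not classes in this package; only the three keys of the returned dict (`input_dims`, `output_dims`, `mapping`) follow `BaseDataModule.config`, and how a model reads them is an assumption.

```python
# Hypothetical names throughout (ToyDataModule, ToyModel); only the shape of
# the dict returned by .config() mirrors BaseDataModule.config in this package.
from typing import Any, Dict, Sequence, Tuple


class ToyDataModule:
    def __init__(self) -> None:
        self.input_dims: Tuple[int, ...] = (1, 28, 28)
        self.output_dims: Tuple[int, ...] = (1,)
        self.mapping: Sequence[str] = [str(digit) for digit in range(10)]

    def config(self) -> Dict[str, Any]:
        # Same three keys that BaseDataModule.config returns.
        return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping}


class ToyModel:
    def __init__(self, data_config: Dict[str, Any]) -> None:
        # Assumed usage: the model sizes itself from the dataset's config
        # instead of hard-coding shapes or the number of classes.
        self.input_dims = data_config["input_dims"]
        self.num_classes = len(data_config["mapping"])


toy_model = ToyModel(ToyDataModule().config())
print(toy_model.input_dims, toy_model.num_classes)
```

In this package the equivalent wiring is passing a data module's `.config()` to a model constructor, as in `LineCNNSimple(emnist_lines.config())`.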
""" from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines ================================================ FILE: lab02/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS class BaseDataModule(pl.LightningDataModule): """Base 
for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`). """ def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab02/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. 
"The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = 
toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") 
shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab02/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab02/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims =
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines) ================================================ FILE: lab02/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab02/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import string from typing import List, Optional import nltk import numpy as np from
text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def __init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in string.punctuation}) text = re.sub(" +", " ", text) return text def 
load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab02/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. 
Parameters ---------- index Returns ------- (datum, target) """ datum, target = self.data[index], self.targets[index] if self.transform is not None: datum = self.transform(datum) if self.target_transform is not None: target = self.target_transform(target) return datum, target def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor: """ Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]: """ Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the other of size (1 - fraction) * size of the base_dataset. """ split_a_size = int(fraction * len(base_dataset)) split_b_size = len(base_dataset) - split_a_size return torch.utils.data.random_split( # type: ignore base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed) ) def resize_image(image: Image.Image, scale_factor: int) -> Image.Image: """Resize image by scale factor.""" if scale_factor == 1: return image return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR) ================================================ FILE: lab02/text_recognizer/lit_models/__init__.py ================================================ from .base import BaseLitModel ================================================ FILE: lab02/text_recognizer/lit_models/base.py ================================================ """Basic LightningModules on which other modules can be built.""" import argparse import pytorch_lightning as pl import torch from torchmetrics import Accuracy OPTIMIZER = "Adam" LR = 1e-3 LOSS = "cross_entropy" ONE_CYCLE_TOTAL_STEPS = 100 class BaseLitModel(pl.LightningModule): """ Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
""" def __init__(self, model, args: argparse.Namespace = None): super().__init__() self.model = model self.args = vars(args) if args is not None else {} self.data_config = self.model.data_config self.mapping = self.data_config["mapping"] self.input_dims = self.data_config["input_dims"] optimizer = self.args.get("optimizer", OPTIMIZER) self.optimizer_class = getattr(torch.optim, optimizer) self.lr = self.args.get("lr", LR) loss = self.args.get("loss", LOSS) if loss not in ("transformer",): self.loss_fn = getattr(torch.nn.functional, loss) self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() @staticmethod def add_to_argparse(parser): parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") parser.add_argument("--lr", type=float, default=LR) parser.add_argument("--one_cycle_max_lr", type=float, default=None) parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional") return parser def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr) if self.one_cycle_max_lr is None: return optimizer scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps ) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} def forward(self, x): return self.model(x) def predict(self, x): logits = self.model(x) return torch.argmax(logits, dim=1) def training_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.train_acc(logits, y) self.log("train/loss", loss) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) outputs = {"loss": loss} return 
outputs def _run_on_batch(self, batch, with_preds=False): x, y = batch logits = self(x) loss = self.loss_fn(logits, y) return x, y, logits, loss def validation_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.val_acc(logits, y) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) outputs = {"loss": loss} return outputs def test_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.test_acc(logits, y) self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) ================================================ FILE: lab02/text_recognizer/metadata/emnist.py ================================================ from pathlib import Path import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json" NUM_SPECIAL_TOKENS = 4 INPUT_SHAPE = (28, 28) DIMS = (1, *INPUT_SHAPE) # Extra dimension added by ToTensor() OUTPUT_DIMS = (1,) MAPPING = [ "<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] ================================================ FILE: lab02/text_recognizer/metadata/emnist_lines.py ================================================ from pathlib import Path import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json" CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3] DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None) # width variable, depends on maximum sequence length MAPPING = emnist.MAPPING ================================================ FILE: lab02/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab02/text_recognizer/metadata/shared.py ================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab02/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple
================================================ FILE: lab02/text_recognizer/models/cnn.py ================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. 
conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab02/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = 
self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab02/text_recognizer/models/mlp.py ================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn 
import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab02/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible.
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab02/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, 
blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab02/training/__init__.py ================================================ ================================================ FILE: lab02/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. 
np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
+ " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. 
To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help ``` """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") trainer.test(datamodule=data, ckpt_path=best_model_path) else: trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab02/training/util.py
================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: lab03/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those 
abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For refreshing on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": "markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the Python standard library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing and\n", "matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models."
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch."
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how good our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5 # learning rate hyperparameter\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs): # loop over the data repeatedly\n", "    for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n", "        start_idx = ii * bs # we are ii batches in, each of size bs\n", "        end_idx = start_idx + bs # and we want the next bs entries\n", "\n", "        # pull batches from x and from y\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "\n", "        # run model\n", "        pred = model(xb)\n", "\n", "        # get loss\n", "        loss = loss_func(pred, yb)\n", "\n", "        # calculate the gradients with a backwards pass\n", "        loss.backward()\n", "\n", "        # update the parameters\n", "        with torch.no_grad(): # we don't want to track gradients through this part!\n", "            # SGD learning rule: update with negative gradient scaled by lr\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "\n", "            # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", "            # until you say so -- by explicitly \"deleting\" them,\n", "            # i.e.
setting the gradients to 0.\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset was, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`."
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
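] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that the new `model` no longer applies `log_softmax`.\n", "As a quick sanity check on that swap, here is a minimal sketch\n", "(using made-up random `demo_logits` and `demo_targets`, not the MNIST data)\n", "suggesting why that is safe:\n", "`F.cross_entropy` expects raw, unnormalized outputs\n", "and applies the log-softmax itself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "\n", "demo_logits = torch.randn(4, 10) # hypothetical raw model outputs for 4 examples\n", "demo_targets = torch.tensor([3, 1, 4, 1]) # hypothetical integer labels\n", "\n", "# hand-rolled version: log-softmax, then negative mean log-probability of the targets\n", "by_hand = -F.log_softmax(demo_logits, dim=1)[range(4), demo_targets].mean()\n", "built_in = F.cross_entropy(demo_logits, demo_targets)\n", "\n", "print(by_hand, built_in, torch.allclose(by_hand, built_in))"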
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__() # the nn.Module.__init__ method does important setup, so this is mandatory\n", "        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", "        self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once."
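] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the `__call__` mechanism a little less magical,\n", "here is a minimal sketch in plain Python.\n", "The `Scaler` class is hypothetical, invented for illustration,\n", "and it simplifies what `nn.Module` does:\n", "the real `nn.Module.__call__` also runs hooks around `.forward`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Scaler:\n", "    \"\"\"Hypothetical toy class: callable like a function, stateful like an object.\"\"\"\n", "\n", "    def __init__(self, factor: float):\n", "        self.factor = factor # the state\n", "\n", "    def forward(self, x: float) -> float:\n", "        return x * self.factor # the computation that uses the state\n", "\n", "    def __call__(self, x: float) -> float:\n", "        return self.forward(x) # calling the object dispatches to .forward\n", "\n", "\n", "double = Scaler(2.0)\n", "print(double(3.0)) # equivalent to double.forward(3.0)"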
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", "    for epoch in range(epochs):\n", "        for ii in range((n - 1) // bs + 1):\n", "            start_idx = ii * bs\n", "            end_idx = start_idx + bs\n", "            xb = x_train[start_idx:end_idx]\n", "            yb = y_train[start_idx:end_idx]\n", "            pred = model(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            with torch.no_grad():\n", "                for p in model.parameters(): # finds params automatically\n", "                    p -= p.grad * lr\n", "                model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`."
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
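] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One detail worth knowing before we swap it in,\n", "shown in the minimal sketch below\n", "(`demo_lin` and `demo_x` are made-up names for illustration):\n", "`nn.Linear` stores its weight matrix with shape `(out_features, in_features)`,\n", "the transpose of our hand-rolled `weights`,\n", "and it picks its own random initialization." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "\n", "demo_lin = nn.Linear(784, 10)\n", "demo_x = torch.randn(2, 784) # two made-up flattened \"images\"\n", "\n", "print(demo_lin.weight.shape, demo_lin.bias.shape) # weight is (10, 784), not (784, 10)\n", "\n", "# the layer computes x @ W.T + b\n", "by_hand = demo_x @ demo_lin.weight.T + demo_lin.bias\n", "print(torch.allclose(demo_lin(demo_x), by_hand, atol=1e-5))"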
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", "    def forward(self, xb):\n", "        return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", "    return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", "    for ii in range((n - 1) // bs + 1):\n", "        start_idx = ii * bs\n", "        end_idx = start_idx + bs\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", "    for ii in range((n - 1) // bs + 1):\n", "        xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "\n", "        loss.backward()\n", "        opt.step()\n", "        opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", "    opt = configure_optimizer(self)\n", "\n", "    for epoch in range(epochs):\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten-line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!"
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", "    for ii in range((n - 1) // bs + 1):\n", "        start_idx = ii * bs\n", "        end_idx = start_idx + bs\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "        pred = model(xb)\n", "        loss = loss_func(pred, yb)\n", "\n", "        loss.backward()\n", "        with torch.no_grad():\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??"
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move electrons around."
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general-purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", "    return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", " \n", " def __init__(self, dir, bs=32):\n", " self.dir = dir\n", " self.bs = bs\n", " self.path = self.dir / self.filename\n", "\n", " def prepare_data(self):\n", " if not (self.path).exists():\n", " content = requests.get(self.url + self.filename).content\n", " self.path.open(\"wb\").write(content)\n", "\n", " def setup(self):\n", " with gzip.open(self.path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", " x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", " )\n", " \n", " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", " def train_dataloader(self):\n", " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", " \n", " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", " datamodule.prepare_data()\n", " datamodule.setup()\n", "\n", " val_dataloader = datamodule.val_dataloader()\n", " \n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", " opt = configure_optimizer(self)\n", " train_dataloader = datamodule.train_dataloader()\n", " for epoch in range(epochs):\n", " self.train()\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab03/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", "    inputs, targets = batch\n", "\n", "    outputs = model(inputs)\n", "    loss = some_loss_function(outputs, targets)\n", " \n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "\n", "    optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes."
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", "    def __init__(self):\n", "        super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", "        # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", "        self.model = torch.nn.Linear(in_features=1, out_features=1)\n", "        # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", "    # optionally, define a forward method\n", "    def forward(self, xs):\n", "        return self.model(xs) # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try: # try using the LinearRegression LightningModule defined above\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n", "\n", "    model = LinearRegression()\n", "\n", "    # we'll explain how the Trainer works in a bit\n", "    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", "    trainer.fit(model=model) \n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", "    print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n", "\n", "finally: # bring back info-level logging\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message 
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", " # run the train step\n", " loss = training_step(batch)\n", "\n", " # clear gradients\n", " optimizer.zero_grad()\n", "\n", " # backprop\n", " loss.backward()\n", "\n", " # update parameters\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " xs, ys = batch # unpack the batch\n", " outs = self(xs) # apply the model\n", " loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n", " return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we 
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudo-code (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", " def closure(): # wrap the loss calculation\n", " loss = model.training_step(batch, batch_idx, ...)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " return loss\n", "\n", " # optimizer can call the loss calculation as many times as it likes\n", " optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", " optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n", " return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n", 
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", " def __init__(self, N=10_000):\n", " self.N = N\n", " self.xs = torch.randn(size=(N, 1))\n", " self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n", "\n", " def __getitem__(self, idx):\n", " return (self.xs[idx], self.ys[idx])\n", "\n", " def __len__(self):\n", " return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging " ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", " def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", " super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", " # set some constants, like the train/val split\n", " self.size = size\n", " self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", " self.train_indices = list(range(math.floor(self.size * train_frac)))\n", " self.val_indices = list(range(len(self.train_indices), self.size)) # starts right after the last train index, so the splits don't overlap\n", "\n", " # under the hood, we've still got a torch Dataset\n", " self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up in each machine\n", "(for example, downloading data from storage and writing it to the local disk)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates best performance practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", " def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n", " super().__init__() # don't forget this!\n", " self.dataset = None\n", " self.test_dataset = None # define a test set -- another sample from the same distribution\n", "\n", " def setup(self, stage=None):\n", " pass\n", "\n", " def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " pass # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab03/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?" 
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
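,
"\n",
"\n",
"As a quick sanity check of that claim, here is a minimal sketch in plain Python.\n",
"The helper names `circular_shift` and `circular_conv` and the weights in `kernel`\n",
"are made up for illustration and are not part of `text_recognizer`.\n",
"It applies the same three weights at every position of a one-dimensional input,\n",
"wrapping around at the edges, and checks that shifting the input and then transforming it\n",
"gives the same result as transforming it and then shifting the output.\n",
"\n",
"```python\n",
"def circular_shift(xs, k):\n",
"    # move every entry k positions to the right, wrapping around at the end\n",
"    k = k % len(xs)\n",
"    return xs[len(xs) - k:] + xs[:len(xs) - k]\n",
"\n",
"\n",
"def circular_conv(xs, kernel):\n",
"    # apply the same weights at every position, wrapping around at the edges\n",
"    n = len(xs)\n",
"    return [sum(w * xs[(i + j) % n] for j, w in enumerate(kernel)) for i in range(n)]\n",
"\n",
"\n",
"xs = [0, 1, 4, 9, 16, 25, 36, 49]\n",
"kernel = [1, -2, 1]  # hypothetical weights, chosen only for illustration\n",
"\n",
"shift_then_conv = circular_conv(circular_shift(xs, 3), kernel)\n",
"conv_then_shift = circular_shift(circular_conv(xs, kernel), 3)\n",
"print(shift_then_conv == conv_then_shift)  # prints True\n",
"```"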
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we apply that kernel over and over again,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type": 
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh" 
 }, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represents (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)." 
 ] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "arguments.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of 
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
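,
"\n",
"\n",
"To make that coupling concrete, here is a minimal sketch in plain Python.\n",
"The dictionary keys and the helper `layer_sizes_from` are assumptions made up for illustration --\n",
"the keys that `text_recognizer` actually uses may differ,\n",
"so inspect the real configuration before relying on them.\n",
"The point is only that the data determines the sizes of a model's first and last layers.\n",
"\n",
"```python\n",
"# hypothetical configuration; the real keys used by text_recognizer may differ\n",
"data_config = {\n",
"    'input_dims': (1, 28, 28),  # channels, height, width\n",
"    'mapping': list('0123456789'),  # one entry per class\n",
"}\n",
"\n",
"\n",
"def layer_sizes_from(data_config, hidden_dim=128):\n",
"    # read the input and output sizes off the data rather than hard-coding them\n",
"    channels, height, width = data_config['input_dims']\n",
"    num_classes = len(data_config['mapping'])\n",
"    return channels * height * width, hidden_dim, num_classes\n",
"\n",
"\n",
"print(layer_sizes_from(data_config))  # (784, 128, 10)\n",
"```"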
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work on,\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors." 
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
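,
"\n",
"\n",
"The padding rule above follows from the standard output-size formula\n",
"for a convolution without dilation:\n",
"`out = (in + 2 * padding - kernel_size) // stride + 1`.\n",
"A minimal sketch in plain Python\n",
"(the function name `conv_output_size` is ours, not part of `torch` or `text_recognizer`):\n",
"\n",
"```python\n",
"def conv_output_size(in_size, kernel_size, stride=1, padding=0):\n",
"    # standard formula for a convolution without dilation\n",
"    return (in_size + 2 * padding - kernel_size) // stride + 1\n",
"\n",
"\n",
"# with unit stride and odd kernel sizes, padding = kernel_size // 2 preserves the size\n",
"for kernel_size in (1, 3, 5, 7):\n",
"    assert conv_output_size(28, kernel_size, padding=kernel_size // 2) == 28\n",
"\n",
"# without padding, a kernel of size 3 trims one pixel from each end\n",
"print(conv_output_size(28, 3))  # 26\n",
"```\n",
"\n",
"The same formula applies independently to height and width."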
 ] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like DropOut\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by the layer can be higher\n", "than that done by layers with more parameters --\n", "much higher." 
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2d layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n", "    num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", "    input_height, input_width = input_size[1], input_size[2]\n", "\n", "    multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", "    num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n", "    multiplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", "    return multiplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
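,
"\n",
"\n",
"If you haven't written a command line interface in Python before,\n",
"the core pattern is small.\n",
"The sketch below uses the standard library's `argparse`.\n",
"The flag names mirror ones used in this lab,\n",
"but the parser, its defaults, and the name `build_parser` are illustrative assumptions,\n",
"not the contents of `run_experiment.py`.\n",
"\n",
"```python\n",
"import argparse\n",
"\n",
"\n",
"def build_parser():\n",
"    # a toy parser: a real training script exposes many more options\n",
"    parser = argparse.ArgumentParser(description='toy experiment runner')\n",
"    parser.add_argument('--model_class', type=str, default='MLP')\n",
"    parser.add_argument('--data_class', type=str, default='MNIST')\n",
"    parser.add_argument('--max_epochs', type=int, default=1)\n",
"    return parser\n",
"\n",
"\n",
"# passing a list of strings stands in for what you would type at the command line\n",
"args = build_parser().parse_args(['--model_class', 'CNN', '--data_class', 'EMNIST'])\n",
"print(args.model_class, args.data_class, args.max_epochs)  # CNN EMNIST 1\n",
"```"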
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n", "filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", "    \"model_class\": \"CNN\",\n", "    \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", "    latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", " --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels." 
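] }, { "cell_type": "markdown", "metadata": {}, "source": [
"As a toy illustration of why labels come for free with synthetic data,\n",
"here is a minimal sketch that glues tiny made-up glyph images into a line.\n",
"The `glyphs` dictionary and the `synthesize_line` function are hypothetical,\n",
"invented for this example -- they are not part of `text_recognizer`.\n",
"\n",
"```python\n",
"# toy stand-ins for handwritten characters: 2x2 \"images\" as nested lists\n",
"# (hypothetical data, not the EMNIST glyphs used in this lab)\n",
"glyphs = {\n",
"    \"a\": [[1, 0], [1, 1]],\n",
"    \"b\": [[0, 1], [1, 1]],\n",
"}\n",
"\n",
"\n",
"def synthesize_line(text):\n",
"    \"\"\"Concatenate glyph images left to right; the label is the text itself.\"\"\"\n",
"    num_rows = len(next(iter(glyphs.values())))\n",
"    image = [[] for _ in range(num_rows)]\n",
"    for char in text:\n",
"        for row, glyph_row in zip(image, glyphs[char]):\n",
"            row.extend(glyph_row)\n",
"    return image, list(text)\n",
"\n",
"\n",
"image, labels = synthesize_line(\"abba\")\n",
"print(len(image), len(image[0]), labels)\n",
"```\n",
"\n",
"Because we chose the text before building the image,\n",
"the label sequence is known exactly, with no annotation step."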
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
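] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The sliding-window application of the `CNN` described above\n",
"can be sketched in plain Python.\n",
"This is only an illustration under simplifying assumptions:\n",
"the names `sliding_windows`, `apply_to_line`, `window_width`, and `window_stride`\n",
"are hypothetical, and the actual `LineCNNSimple.forward` may be organized differently,\n",
"e.g. operating on batched tensors rather than nested lists.\n",
"\n",
"```python\n",
"def sliding_windows(line_image, window_width, window_stride):\n",
"    \"\"\"Split a wide image (a list of rows) into windows along the width.\"\"\"\n",
"    width = len(line_image[0])\n",
"    starts = range(0, width - window_width + 1, window_stride)\n",
"    return [[row[s:s + window_width] for row in line_image] for s in starts]\n",
"\n",
"\n",
"def apply_to_line(char_model, line_image, window_width, window_stride):\n",
"    \"\"\"Apply a single-character model to each window, left to right.\"\"\"\n",
"    windows = sliding_windows(line_image, window_width, window_stride)\n",
"    return [char_model(window) for window in windows]\n",
"\n",
"\n",
"# a stand-in \"model\" that just sums the pixels in its window\n",
"toy_line = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]\n",
"toy_model = lambda window: sum(map(sum, window))\n",
"print(apply_to_line(toy_model, toy_line, window_width=2, window_stride=2))\n",
"```\n",
"\n",
"With a stride smaller than the window width, neighboring windows overlap,\n",
"so the same pixels contribute to more than one prediction."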
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `

`.\n", "\n", "A model that always predicts `

` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"

\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab03/notebooks/lab03_transformers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 03: Transformers and Paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The fundamental reasons why the Transformer is such\n", "a powerful and popular architecture\n", "- Core intuitions for the behavior of Transformer architectures\n", "- How to use a convolutional encoder and a Transformer decoder to recognize\n", "entire paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 3\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", "    # path management for Python\n", "    pythonpath, = !echo $PYTHONPATH\n", "    if \".\" not in pythonpath.split(\":\"):\n", "        pythonpath = \".:\" + pythonpath\n", "        %env PYTHONPATH={pythonpath}\n", "        !echo $PYTHONPATH\n", "\n", "    # get both Colab and local notebooks into the same state\n", "    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", "    import bootstrap\n", "\n", "    # change into the lab directory\n", "    bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", "    # allow \"hot-reloading\" of modules\n", "    %load_ext autoreload\n", "    %autoreload 2\n", "    # needed for inline plots in some contexts\n", "    %matplotlib inline\n", "\n", "    bootstrap.run = False # change to True to re-run setup\n", "    \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Transformers?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal in building a text recognizer is to take a two-dimensional image\n", "and convert it into a one-dimensional sequence of characters\n", "from some alphabet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional neural networks,\n", "discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n", "are great at encoding images,\n", "taking them from their raw pixel values\n", "to a more semantically meaningful numerical representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But how do we go from that to a sequence of letters?\n", "And what's especially tricky:\n", "the number of letters in an image is separable from its size.\n", "A screenshot of this document has a much higher density of letters\n", "than a close-up photograph of a piece of paper.\n", "How do we get a _variable-length_ sequence of letters,\n", "where the length need have nothing to do with the size of the input tensor?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n", "they were\n", "[originally introduced](https://arxiv.org/abs/1706.03762)\n", "for transforming one sequence into another,\n", "as in machine translation.\n", "This makes them a natural fit for processing language.\n", "\n", "But they have also found success in other domains --\n", "at the time of this writing, large transformers\n", "dominate the\n", "[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n", "that has become a de facto standard for comparing models\n", "and are finding\n", "[application in reinforcement learning](https://arxiv.org/abs/2106.01345)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we will use a Transformer as a key component of our final architecture:\n", "we will encode our input images with a CNN\n", "and then read them out into a text sequence with a Transformer.\n", "\n", "Before trying out this new model,\n", "let's first get an understanding of why the Transformer architecture\n", "has become so popular by walking through its history\n", "and then get some intuition for how it works\n", "by looking at some\n", "[recent work](https://transformer-circuits.pub/)\n", "on explaining the behavior of both toy models and state-of-the-art language models." 
] }, { "cell_type": "markdown", "metadata": { "id": "kmKqjbvd-Mj3" }, "source": [ "## Why not convolutions?" ] }, { "cell_type": "markdown", "metadata": { "id": "SRqkUMdM-OxU" }, "source": [ "In the ancient beforetimes (i.e. 2016),\n", "the best models for natural language processing were all\n", "_recurrent_ neural networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional networks were also occasionally used,\n", "but they suffered from a serious issue:\n", "their architectural biases don't fit text.\n", "\n", "First, _translation equivariance_ no longer holds.\n", "The beginning of a piece of text is often quite different from the middle,\n", "so the absolute position matters.\n", "\n", "Second, _locality_ is not as important in language.\n", "The name of a character that hasn't appeared in thousands of pages\n", "can become salient when someone asks, \"Whatever happened to\n", "[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n", "\n", "Consider interpreting a piece of text like the Python code below:\n", "```python\n", "def do(arg1, arg2, arg3):\n", " a = arg1 + arg2\n", " b = arg3[:3]\n", " c = a * b\n", " return c\n", "\n", "print(do(1, 1, \"ayy lmao\"))\n", "```\n", "\n", "After a `(` we expect a `)`,\n", "but possibly very long afterwards,\n", "[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n", "and similarly we expect a `]` at some point after a `[`.\n", "\n", "For translation variance, consider\n", "that we interpret `*` not by\n", "comparing it to its neighbors\n", "but by looking at `a` and `b`.\n", "We mix knowledge learned through experience\n", "with new facts learned while reading --\n", "also known as _in-context learning_.\n", "\n", "In a longer text,\n", "[e.g. 
the one you are reading now](./lab03_transformers.ipynb),\n", "the translation variance of text is clearer.\n", "Every lab notebook begins with the same header,\n", "setting up the environment,\n", "but that header never appears elsewhere in the notebook.\n", "Later positions need to be processed in terms of the previous entries.\n", "\n", "Unlike an image, we cannot simply rotate or translate our \"camera\"\n", "and get a new valid text.\n", "[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n", "that can be read without regard to position." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The field of formal language theory,\n", "which has deep mutual influence with computer science,\n", "gives one way of explaining the issues with convolutional networks:\n", "they can only understand languages with _finite contexts_,\n", "where all the information can be found within a finite window." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The immediate solution, drawing from the connections to computer science, is\n", "[recursion](https://www.google.com/search?q=recursion).\n", "A network whose output on the final entry of the sequence is a recursive function\n", "of all the previous entries can build up knowledge\n", "as it reads the sequence and treat early entries quite differently than it does late ones." 
] }, { "cell_type": "markdown", "metadata": { "id": "aa6cbTlImkEh" }, "source": [ "In pseudo-code, such a _recurrent neural network_ module might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "lKtBoPnglPrW" }, "source": [ "```python\n", "def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n", "    next_inputs = input_module(xs[-1])\n", "    next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n", "    return output_module(next_inputs, next_hiddens)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "IbJPSMnEm516" }, "source": [ "If you've had formal computer science training,\n", "then you may be familiar with the power of recursion,\n", "e.g. the\n", "[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n", "that gave its name to the now much better-known\n", "[startup incubator](https://www.ycombinator.com/).\n", "\n", "The particular form of recursion used by\n", "recurrent neural networks implements a\n", "[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n", "\n", "> If you know a lot of computer science,\n", "you might be concerned by this connection.\n", "What about other\n", "[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n", "Where are the neural network architectures for differentiable\n", "[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n", "Check out Graph Neural Networks,\n", "[which implement dynamic programming](https://arxiv.org/abs/2203.15544)." 
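] }, { "cell_type": "markdown", "metadata": {}, "source": [
"To make the reduce-like structure of recurrence concrete,\n",
"here is a minimal sketch with scalar inputs and a scalar hidden state.\n",
"The `step` function and its weights `w_h` and `w_x` are arbitrary choices made for illustration,\n",
"not taken from any particular RNN implementation.\n",
"\n",
"```python\n",
"import math\n",
"from functools import reduce\n",
"\n",
"\n",
"def step(hidden, x, w_h=0.5, w_x=1.0):\n",
"    \"\"\"One recurrent update: mix the previous hidden state with the new input.\"\"\"\n",
"    return math.tanh(w_h * hidden + w_x * x)\n",
"\n",
"\n",
"def rnn_as_reduce(xs, initial_hidden=0.0):\n",
"    \"\"\"Fold the step function over the sequence, left to right.\"\"\"\n",
"    return reduce(step, xs, initial_hidden)\n",
"\n",
"\n",
"def rnn_as_recursion(xs, initial_hidden=0.0):\n",
"    \"\"\"The same computation, written as an explicit recursive function.\"\"\"\n",
"    if not xs:  # base case: nothing has been read yet\n",
"        return initial_hidden\n",
"    return step(rnn_as_recursion(xs[:-1], initial_hidden), xs[-1])\n",
"\n",
"\n",
"xs = [0.1, -0.3, 0.7]\n",
"assert rnn_as_reduce(xs) == rnn_as_recursion(xs)\n",
"```\n",
"\n",
"Both versions apply `step` to the entries in the same order,\n",
"so the final hidden state depends on every earlier entry of the sequence."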
] }, { "cell_type": "markdown", "metadata": { "id": "63mMTbEBpVuE" }, "source": [ "Recurrent networks are able to achieve\n", "[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n", "\n", "There are many popular recurrent architectures,\n", "from the beefy and classic\n", "[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n", "to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n", "([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n", "all of which have roughly similar capabilities but\n", "[some of which are easier to train](https://arxiv.org/abs/1611.09913)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwQHVTIslOku" }, "source": [ "In the same sense that MLPs can model \"any\" feedforward function,\n", "in principle even basic RNNs\n", "[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n", "\n", "In particular they can model any\n", "[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n", "which is a formal way of saying that they can in principle\n", "do anything a computer is capable of doing.\n", "\n", "The question is then..." ] }, { "cell_type": "markdown", "metadata": { "id": "3J8EoGN3pu7P" }, "source": [ "## Why aren't we all using RNNs?" 
] }, { "cell_type": "markdown", "metadata": { "id": "TDwNWaevpt_3" }, "source": [ "The guarantees that MLPs can model any function\n", "or that RNNs can model Turing machines\n", "provide decent intuition but are not directly practically useful.\n", "Among other reasons, they don't guarantee learnability --\n", "that starting from random parameters we can find the parameters\n", "that implement a given function.\n", "The\n", "[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n", "than would seem from basic theoretical and empirical analysis.\n", "\n", "One way of understanding capacity to model language is\n", "[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n", "In this model of formal languages,\n", "Turing machines sit at the top\n", "([practically speaking](https://arxiv.org/abs/math/0209332)).\n", "\n", "With better mathematical models,\n", "RNNs and LSTMs can be shown to be\n", "[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n", "with RNNs looking more like\n", "[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n", "and LSTMs coming in\n", "[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n", "\n", "More controversially:\n", "the Chomsky hierarchy is great for understanding syntax and grammar,\n", "which makes it great for building parsers\n", "and working with formal languages,\n", "but the goal in _natural_ language processing is to understand _natural_ language.\n", "Most humans' natural language is far from strictly grammatical,\n", "but that doesn't mean it is nonsense.\n", "\n", "And to really \"understand\" language means\n", "to understand its semantic content, which is fuzzy.\n", "The most important thing for handling the fuzzy semantic content\n", "of language is not whether you can recall\n", "[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n", "but whether you can 
model probabilistic relationships between concepts\n", "in addition to grammar and syntax." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These both leave theoretical room for improvement over current recurrent\n", "language and sequence models.\n", "\n", "But the real cause of the rise of Transformers is that..." ] }, { "cell_type": "markdown", "metadata": { "id": "Dsu1ebvAp-3Z" }, "source": [ "## Transformers are designed to train fast at scale on contemporary hardware." ] }, { "cell_type": "markdown", "metadata": { "id": "c4abU5adsPGs" }, "source": [ "The Transformer architecture has several important features,\n", "discussed below,\n", "but one of the most important reasons why it is successful\n", "is that it can be more easily trained at scale.\n", "\n", "This scalability is the focus of the discussion in the paper\n", "that introduced the architecture,\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", "and\n", "[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n", "\n", "The recursion in RNNs is inherently sequential:\n", "the dependence on the outputs from earlier in the sequence\n", "means computations within an example cannot be parallelized.\n", "\n", "So RNNs must batch across examples to scale,\n", "but as sequence length grows this hits memory bandwidth limits.\n", "Serving up large batches quickly with good randomness guarantees\n", "is also hard to optimize,\n", "especially in distributed settings.\n", "\n", "The Transformer architecture,\n", "on the other hand,\n", "can be readily parallelized within a single example sequence,\n", "in addition to parallelization across batches.\n", "This can lead to massive performance gains for a fixed scale,\n", "which means larger, higher capacity models\n", "can be trained on larger datasets." 
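] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of this difference, here is a minimal sketch\n", "(not code from this repository; the sizes and variable names are made up for the example)\n", "that feeds one sequence through a recurrent cell and through an attention layer:\n", "\n", "```python\n", "import torch\n", "\n", "seq_len, d_model = 16, 32  # hypothetical sizes, chosen only for illustration\n", "x = torch.randn(seq_len, 1, d_model)  # (sequence, batch, features)\n", "\n", "# recurrent: step t needs the hidden state from step t - 1, so we must loop\n", "cell = torch.nn.RNNCell(d_model, d_model)\n", "h = torch.zeros(1, d_model)\n", "rnn_outputs = []\n", "for x_t in x:  # seq_len dependent steps, one after another\n", "    h = cell(x_t, h)\n", "    rnn_outputs.append(h)\n", "\n", "# attention: all positions are computed from the same inputs in a single call\n", "attention = torch.nn.MultiheadAttention(d_model, num_heads=4)\n", "attn_outputs, _ = attention(x, x, x)\n", "\n", "print(torch.stack(rnn_outputs).shape, attn_outputs.shape)\n", "```\n", "\n", "Both produce one vector per position,\n", "but only the loop forces the hardware to wait on the previous step." 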
] }, { "cell_type": "markdown", "metadata": { "id": "_Mzk2haFC_G1" }, "source": [ "How does the architecture achieve this parallelizability?\n", "\n", "Let's start with the architecture diagram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u59eu4snLQfp" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n", "\n", "display.Image(url=base_url + \"/aiayn-figure-1.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ez-XEQ7M0UlR" }, "source": [ "> To head off a bit of confusion\n", " in case you've worked with Transformer architectures before:\n", " the original \"Transformer\" is an encoder/decoder architecture.\n", " Many LLMs, like GPT models, are decoder only,\n", " because this has turned out to scale well,\n", " and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n", " it's all text anyways.\n", " We, however, will be using them across modalities,\n", " so we need an explicit encoder,\n", " as above. " ] }, { "cell_type": "markdown", "metadata": { "id": "ok4ksBi4vp89" }, "source": [ "First focusing on the encoder (left):\n", "the encoding at a given position is a function of all previous inputs.\n", "But it is not a function of the previous _encodings_:\n", "we produce the encodings \"all at once\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "RPN7C-_OqzHP" }, "source": [ "The decoder (right) does use previous \"outputs\" as its inputs,\n", "but those outputs are not the vectors of layer activations\n", "(aka embeddings)\n", "that are produced by the network.\n", "They are instead the processed outputs,\n", "after a `softmax` and an `argmax`.\n", "\n", "We could obtain these outputs by processing the embeddings,\n", "much like in a recurrent architecture.\n", "In fact, that is one way that Transformers are run.\n", "It's what happens in the `.forward` method\n", "of the model we'll be training for character recognition:\n", "`ResnetTransformer`." ] }, { "cell_type": "markdown", "metadata": { "id": "L5_2WMmtDnJn" }, "source": [ "Let's look at that forward method\n", "and connect it to the diagram." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FR5pk4kEyCGg" }, "outputs": [], "source": [ "from text_recognizer.models import ResnetTransformer\n", "\n", "\n", "ResnetTransformer.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "-J5UFDoPzPbq" }, "source": [ "`.encode` happens first -- that's the left side of diagram.\n", "\n", "The encoder can in principle be anything\n", "that produces a sequence of fixed-length vectors,\n", "but here it's\n", "[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n", "\n", "Then we start iterating over the sequence\n", "in the `for` loop.\n", "\n", "Focus on the first few lines of code.\n", "We apply `.decode` (right side of diagram)\n", "to the outputs so far.\n", "\n", "Once we have a new `output`, we apply `.argmax`\n", "to turn the logits into a concrete prediction of\n", "a particular token.\n", "\n", "This is added as the last output token\n", "and then the loop happens again." 
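] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Stripped of the details, the loop described above has roughly the following shape.\n", "This is a toy sketch rather than the `ResnetTransformer` code:\n", "`toy_decode`, `table`, and the sizes are hypothetical stand-ins for a real decoder.\n", "\n", "```python\n", "import torch\n", "\n", "torch.manual_seed(0)\n", "vocab_size, max_len, start_token = 10, 5, 0  # hypothetical values\n", "table = torch.randn(vocab_size, vocab_size)  # toy model: logits looked up by token\n", "\n", "\n", "def toy_decode(tokens):\n", "    # stand-in for a real decoder: one row of logits per position so far\n", "    return table[tokens]\n", "\n", "\n", "tokens = torch.tensor([start_token])\n", "for _ in range(max_len - 1):\n", "    logits = toy_decode(tokens)  # (positions so far, vocab_size)\n", "    next_token = logits[-1].argmax()  # concrete prediction for the next position\n", "    tokens = torch.cat([tokens, next_token.unsqueeze(0)])  # feed it back in\n", "\n", "print(tokens)\n", "```\n", "\n", "Each pass through the loop depends on the token chosen in the previous pass,\n", "so the iterations cannot run in parallel." 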
] }, { "cell_type": "markdown", "metadata": { "id": "LTcy8-rV1dHr" }, "source": [ "Run this way, our model looks very much like a recurrent architecture:\n", "we call the model on its own outputs\n", "to generate the next value.\n", "These types of models are also referred to as\n", "[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n", "because we predict (as we do in _regression_)\n", "the next value based on our own (_auto_) output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But Transformers are designed to be _trained_ more scalably than RNNs,\n", "not necessarily to _run inference_ more scalably,\n", "and it's actually not the case that our model's `.forward` is called during training." ] }, { "cell_type": "markdown", "metadata": { "id": "eCxMSAWmEKBt" }, "source": [ "Let's look at what happens during training\n", "by checking the `training_step`\n", "of the `LightningModule`\n", "we use to train our Transformer models,\n", "the `TransformerLitModel`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0o7q8N7P2w4H" }, "outputs": [], "source": [ "from text_recognizer.lit_models import TransformerLitModel\n", "\n", "TransformerLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "1VgNNOjvzC4y" }, "source": [ "Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`." ] }, { "cell_type": "markdown", "metadata": { "id": "tz-6NGPR4dUr" }, "source": [ "Let's look at `.teacher_forward`,\n", "and in particular its type signature:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ILc2oWET4i2Z" }, "outputs": [], "source": [ "TransformerLitModel.teacher_forward??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`." 
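] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Feeding the ground truth targets to the decoder can be sketched as follows.\n", "This is a simplified illustration, not the `TransformerLitModel` implementation:\n", "`toy_decoder` and the sizes are hypothetical,\n", "and a real decoder would also attend to earlier positions under a causal mask.\n", "\n", "```python\n", "import torch\n", "\n", "torch.manual_seed(0)\n", "vocab_size, seq_len = 10, 6  # hypothetical sizes\n", "y = torch.randint(0, vocab_size, (1, seq_len))  # ground truth tokens, (batch, sequence)\n", "\n", "decoder_input = y[:, :-1]  # what the decoder sees: the true tokens so far\n", "targets = y[:, 1:]  # what it should predict: the true next token at each position\n", "\n", "# stand-in for a real decoder: logits for every position from one call\n", "toy_decoder = torch.nn.Sequential(\n", "    torch.nn.Embedding(vocab_size, 16), torch.nn.Linear(16, vocab_size)\n", ")\n", "logits = toy_decoder(decoder_input)  # (batch, sequence - 1, vocab_size)\n", "\n", "loss = torch.nn.functional.cross_entropy(\n", "    logits.reshape(-1, vocab_size), targets.reshape(-1)\n", ")\n", "print(logits.shape, loss.item())\n", "```\n", "\n", "All positions are scored in one call, with no loop over the sequence." 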
] }, { "cell_type": "markdown", "metadata": { "id": "lf32lpgrDb__" }, "source": [ "This is known as \"teacher forcing\".\n", "The \"teacher\" signal is \"forcing\"\n", "the model to behave as though\n", "it got the answer right.\n", "\n", "[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n", "It's more effective here\n", "because the right teaching signal\n", "for our network is the target data,\n", "which we have access to during training,\n", "whereas in an RNN the best teaching signal\n", "would be the target embedding vector,\n", "which we do not know.\n", "\n", "During inference, when we don't have access to the ground truth,\n", "we revert to the autoregressive `.forward` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This \"trick\" allows Transformer architectures to readily scale\n", "up models to the parameter counts\n", "[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)." ] }, { "cell_type": "markdown", "metadata": { "id": "BAjqpJm9uUuU" }, "source": [ "## Is there more to Transformers than just a training trick?" ] }, { "cell_type": "markdown", "metadata": { "id": "kWCYXeHv7Qc9" }, "source": [ "[Very](https://arxiv.org/abs/2005.14165),\n", "[very](https://arxiv.org/abs/1909.08053),\n", "[very](https://arxiv.org/abs/2205.01068)\n", "large Transformer models have powered the most recent wave of exciting results in ML, like\n", "[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n", "\n", "They are also the first machine learning models to have come anywhere close to\n", "deserving the term _artificial intelligence_ --\n", "a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is surprising because the models and their training procedure are\n", "(relatively speaking)\n", "pretty _simple_,\n", "even if it doesn't feel that way on first pass." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic Transformer architecture is just a bunch of\n", "dense matrix multiplications and non-linearities --\n", "it's perhaps simpler than a convolutional architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And advances since the introduction of Transformers in 2017\n", "have not in the main been made by\n", "creating more sophisticated model architectures\n", "but by increasing the scale of the base architecture,\n", "or if anything making it simpler, as in\n", "[GPT-type models](https://arxiv.org/abs/2005.14165),\n", "which drop the encoder." ] }, { "cell_type": "markdown", "metadata": { "id": "V1HQS9ey8GMc" }, "source": [ "These models are also trained on very simple tasks:\n", "most LLMs are just trying to predict the next element in the sequence,\n", "given the previous elements --\n", "a task simple enough that Claude Shannon,\n", "father of information theory, was\n", "[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n", "\n", "These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n", "e.g. by scraping the web." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They are also trained in a simple fashion:\n", "first-order stochastic optimizers, like SGD or an\n", "[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n", "intended for the most basic of optimization problems,\n", "that scale more readily than the second-order optimizers\n", "that dominate other areas of optimization." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kz9HPDoy7OAl" }, "source": [ "This is\n", "[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n", "of work in ML:\n", "simple, even seemingly wasteful,\n", "architectures that scale well and are robust\n", "to implementation details\n", "eventually outstrip more clever but\n", "also more finicky approaches that are harder to scale.\n", "This lesson has led some to declare that\n", "[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n", "in machine learning, and perhaps even in artificial intelligence." ] }, { "cell_type": "markdown", "metadata": { "id": "SdN9o2Y771YZ" }, "source": [ "> That is not to say that because the algorithms are relatively simple,\n", " training a model at this scale is _easy_ --\n", " [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n", " [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n", " [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n", " But choosing the simplest algorithm at every step makes solving the scaling problem feasible." ] }, { "cell_type": "markdown", "metadata": { "id": "baVGf6gKFOvs" }, "source": [ "The importance of scale is the key lesson from the Transformer architecture,\n", "far more than any theoretical considerations\n", "or any of the implementation details.\n", "\n", "That said, these large Transformer models are capable of\n", "impressive behaviors and understanding how they achieve them\n", "is of intellectual interest.\n", "Furthermore, like any architecture,\n", "there are common failure modes,\n", "of the model and of the modelers who use them,\n", "that need to be taken into account." 
] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "Below, we'll cover two key intuitions about Transformers:\n", "Transformers are _residual_, like ResNets,\n", "and they compose _low rank_ sequence transformations.\n", "Together, this means they act somewhat like a computer,\n", "reading from and writing to a \"tape\" or memory\n", "with a sequence of simple instructions." ] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "We'll also cover a surprising implementation detail:\n", "despite being commonly used for sequence modeling,\n", "by default the architecture is _position insensitive_." ] }, { "cell_type": "markdown", "metadata": { "id": "uni0VTCr9lev" }, "source": [ "### Intuition #1: Transformers are highly residual." ] }, { "cell_type": "markdown", "metadata": { "id": "0MoBt-JLJz-d" }, "source": [ "> The discussion of these intuitions summarizes the discussion in\n", "[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n", "from\n", "[Anthropic](https://www.anthropic.com/),\n", "an AI safety and research company.\n", "The figures below are from that blog post.\n", "It is the spiritual successor to the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "covered in\n", "[Lab 02b](https://lab02b-colab).\n", "If you want to truly understand Transformers,\n", "we highly recommend you check it out,\n", "including the\n", "[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "UUbNVvM5Ferm" }, "source": [ "It's easy to see that ResNets are residual --\n", "it's in the name, after all.\n", "\n", "But Transformers are,\n", "in some sense,\n", "even more closely tied to residual computation\n", "than are ResNets:\n", "ResNets and related architectures include downsampling,\n", "so there is not a direct path from inputs to outputs.\n", "\n", "In Transformers, the exact same shape is maintained\n", "from the moment tokens are embedded,\n", "through dozens or hundreds of intermediate layers,\n", "and until they are \"unembedded\" into class logits.\n", "The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n", "\n", "The residual stream is easy to see with a change of perspective.\n", "Instead of the usual architecture diagram above,\n", "which emphasizes the layers acting on the tensors,\n", "consider this alternative view,\n", "which emphasizes the tensors as they pass through the layers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HRMlVguKKW6y" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-residual-view.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9K3N7ilVkB3" }, "source": [ "For definitions of variables and terms, see the\n", "[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)." ] }, { "cell_type": "markdown", "metadata": { "id": "arvciE-kKd_L" }, "source": [ "Note that this is a _decoder-only_ Transformer architecture --\n", "so it should be compared with the right-hand side of the original architecture diagram above." 
] }, { "cell_type": "markdown", "metadata": { "id": "wvrRMd_RKp_G" }, "source": [ "Notice that outputs of the attention blocks \n", "and of the MLP layers are\n", "added to their inputs, as in a ResNet.\n", "These operations are represented as \"Add & Norm\" layers in the classical diagram;\n", "normalization is ignored here for simplicity." ] }, { "cell_type": "markdown", "metadata": { "id": "o8n_iT-FFAbK" }, "source": [ "This total commitment to residual operations\n", "means the size of the embeddings\n", "(referred to as the \"model dimension\" or the \"embedding dimension\",\n", "here and below `d_model`)\n", "stays the same throughout the entire network.\n", "\n", "That means, for example,\n", "that the output of each layer can be used as input to the \"unembedding\" layer\n", "that produces logits.\n", "We can read out the computations of intermediate layers\n", "just by passing them through the unembedding layer\n", "and examining the logit tensor.\n", "See\n", "[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n", "for detailed experiments and interactive notebooks.\n", "\n", "In short, we observe a sort of \"progressive refinement\"\n", "of the next-token prediction\n", "as the embeddings proceed, depthwise, through the network." ] }, { "cell_type": "markdown", "metadata": { "id": "Ovh_3YgY9z2h" }, "source": [ "### Intuition #2: Transformer heads learn low rank transformations." 
] }, { "cell_type": "markdown", "metadata": { "id": "XpNmozlnOdPC" }, "source": [ "In the original paper and in\n", "most presentations of Transformers,\n", "the attention layer is written like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PA7me8gNP5LE" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In pseudo-typed PyTorch (based loosely on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n", "that looks like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Oeict_6wGJgD" }, "source": [ "```python\n", "def classic_attention(\n", "    Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n", "    K: torch.Tensor[\"d_sequence\", \"d_model\"],\n", "    V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", "    return torch.softmax(Q @ K.T, dim=-1) @ V  # normalize over the positions attended to\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "8pewU90DSuOR" }, "source": [ "This is almost exactly\n", "how it is written\n", "in PyTorch,\n", "apart from implementation details\n", "(look for `bmm` for the matrix multiplications and a `softmax` call):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrgTpKFvOhwc" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "F._scaled_dot_product_attention??" 
] }, { "cell_type": "markdown", "metadata": { "id": "ebDXZ0tlSe7g" }, "source": [ "But the best way to write an operation so that a computer can execute it quickly\n", "is not necessarily the best way to write it so that a human can understand it --\n", "otherwise we'd all be coding in assembly.\n", "\n", "And this is a strange way to write it --\n", "you'll notice that what we normally think of\n", "as the \"inputs\" to the layer are not shown.\n", "\n", "We can instead write out the attention layer\n", "as a function of the inputs $x$.\n", "We write it for a single \"attention head\".\n", "Each attention layer includes a number of heads\n", "that read from and write to the residual stream\n", "simultaneously and independently.\n", "We also add the output layer weights $W_O$\n", "and we get:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LuFNR67tQpsf" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(\\underbrace{x W_Q^T}_Q \\underbrace{W_K x^T}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SVnBjjfOLwxP" }, "source": [ "or, in pseudo-typed PyTorch:" ] }, { "cell_type": "markdown", "metadata": { "id": "LmpOm-HfGaNz" }, "source": [ "```python\n", "def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", "    query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n", "    key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n", "    key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n", "    # maps queries of residual stream to keys from residual stream, independent of position\n", "\n", "    value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n", "    output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n", "    value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n", "    # transformation applied to each token, regardless of position\n", "\n", "    attention_logits: torch.Tensor[\"d_sequence\", \"d_sequence\"] = x @ key_query_circuit @ x.T\n", "    attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n", "    # maps positions to positions, often very sparse\n", "\n", "    value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n", "\n", "    return attention_map @ value_output  # transformed tokens filtered by attention map\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "dC0eqxZ6UAGT" }, "source": [ "Consider the `key_query_circuit`\n", "and `value_output_circuit`\n", "matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n", "\n", "The key/query dimension, `d_head`,\n", "is small relative to the model's dimension, `d_model`,\n", "so $W_{QK}$ and $W_{OV}$ are very low rank,\n", "[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n", "that they factorize into two matrices,\n", "one with a smaller number of rows\n", "and another with a smaller number of columns.\n", "That number is called the _rank_.\n", "\n", "When computing, these matrices are better represented via their components,\n", "rather than computed directly,\n", "which leads to the normal implementation of attention.\n", "\n", "In a large language model,\n", "the ratio of residual stream dimension, `d_model`, to\n", "the dimension of a single head, `d_head`, is huge, often 100:1.\n", "That means each query, key, and value computed at a position\n", "is a fairly simple, low-dimensional feature of the residual stream at that position.\n", "\n", "For visual intuition,\n", "we compare a matrix whose rank is 1/100th of full rank\n", "with a full rank matrix of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LUbojJMiW2C" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "\n", "low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n", "full_rank = torch.randn(100, 100)\n", "plt.figure(); 
plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n", "plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");" ] }, { "cell_type": "markdown", "metadata": { "id": "lqBst92-OVka" }, "source": [ "The pattern in the first matrix is very simple,\n", "relative to the pattern in the second matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "SkCGrs9EiVh4" }, "source": [ "Another feature of low rank transformations is\n", "that they have a large nullspace or kernel --\n", "these are directions we can move the input without changing the output.\n", "\n", "That means that many changes to the residual stream won't affect the behavior of this head at all." ] }, { "cell_type": "markdown", "metadata": { "id": "UVz2dQgzhD4p" }, "source": [ "### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)." ] }, { "cell_type": "markdown", "metadata": { "id": "hVlzwR03m8mC" }, "source": [ "The combination of residuality\n", "(changes are added to the current input)\n", "and low rank\n", "(only a small subspace is changed by each head)\n", "drastically changes the intuition about Transformers." ] }, { "cell_type": "markdown", "metadata": { "id": "qqjZI2jKe6HH" }, "source": [ "Rather than being an \"embedding of a token in its context\",\n", "the residual stream becomes something more like a memory or a scratchpad:\n", "one layer reads a small bit of information from the stream\n", "and writes a small bit of information back to it." 
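] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the \"small read, small write\" picture concrete,\n", "here is a minimal sketch with random weights\n", "(the sizes are hypothetical, and the residual stream at one position is treated as a single vector).\n", "It checks that the combined value/output matrix has rank at most `d_head`\n", "and that moving the stream along a direction in the nullspace of `W_V`\n", "leaves what the head writes unchanged:\n", "\n", "```python\n", "import torch\n", "\n", "torch.manual_seed(0)\n", "d_model, d_head = 64, 4  # hypothetical sizes\n", "W_V = torch.randn(d_head, d_model)  # reads d_head features from the stream\n", "W_O = torch.randn(d_model, d_head)  # writes them back into the stream\n", "W_OV = W_O @ W_V  # (d_model, d_model), but rank at most d_head\n", "\n", "print(torch.linalg.matrix_rank(W_OV))\n", "\n", "# build a direction the head cannot see: take a random vector and\n", "# remove the part of it that lies in the row space of W_V\n", "v = torch.randn(d_model)\n", "row_space_projector = torch.linalg.pinv(W_V) @ W_V\n", "null_direction = v - row_space_projector @ v\n", "\n", "x = torch.randn(d_model)  # residual stream at one position\n", "print(torch.allclose(W_OV @ x, W_OV @ (x + null_direction), atol=1e-3))\n", "```\n", "\n", "With `d_model = 64` and `d_head = 4`,\n", "60 of the 64 directions in the stream are invisible to this head." 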
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YIBkxlqepjc" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-layer-residual.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtsKhkLfk00l" }, "source": [ "The residual stream works like a memory because it is roomy enough\n", "that these actions need not interfere:\n", "the subspaces targeted by reads and writes are small relative to the ambient space,\n", "so they can coexist without overlapping.\n", "\n", "Additionally, the dimension of each head is still in the 100s in large models,\n", "and\n", "[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n", "in them, so the effective number of degrees of freedom is\n", "actually larger than the dimension.\n", "This phenomenon allows high-dimensional tensors to serve as\n", "[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n", "There are\n", "[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n", "\n", "Together, this means an early layer can write information to the stream\n", "that can be used by later layers -- by many of them at once, possibly much later.\n", "Later layers can learn to edit this information,\n", "e.g. deleting it,\n", "if doing so reduces the loss,\n", "but by default the information is preserved." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EragIygzJg86" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-stream-read-write.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "oKIaUZjwkpW7" }, "source": [ "Lastly, the softmax in the attention has a sparsifying effect,\n", "and so many attention heads are reading from \n", "just one token and writing to just one other token." 
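] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The sparsifying effect is easy to see numerically.\n", "In this small sketch (random logits, with scale factors chosen arbitrarily for illustration),\n", "the largest attention weight moves toward 1 as the logits grow in magnitude:\n", "\n", "```python\n", "import torch\n", "\n", "torch.manual_seed(0)\n", "logits = torch.randn(8)  # attention logits for one query over 8 positions\n", "\n", "for scale in (1.0, 5.0, 25.0):  # larger logits give sharper attention\n", "    weights = torch.softmax(scale * logits, dim=-1)\n", "    print(scale, weights.max().item())\n", "```\n", "\n", "When one logit dominates, nearly all of the attention lands on a single position." 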
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dN6VcJqIMKnB" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-token-to-token.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Repeatedly reading information from an external memory\n", "and using it to decide which operation to perform\n", "and where to write the results\n", "is at the core of the\n", "[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n", "For a concrete example, the\n", "[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n", "includes a dissection of a form of \"pointer arithmetic\"\n", "that appears in some models." ] }, { "cell_type": "markdown", "metadata": { "id": "0kLFh7Mvnolr" }, "source": [ "This point of view seems\n", "very promising for explaining numerous\n", "otherwise perhaps counterintuitive features of Transformer models.\n", "\n", "- This framework predicts that Transformers will readily copy-and-paste information,\n", "which might explain phenomena like\n", "[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n", "\n", "- It also readily explains\n", "[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n", "an important component of why Transformers perform well on medium-length texts\n", "and in few-shot learning.\n", "\n", "- Transformers also perform better on reasoning tasks when the text\n", "[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n", "is added to their input prompt.\n", "This is partly due to the fact that that prompt is associated,\n", "in the dataset, with clearer reasoning,\n", "and since the models are trained to predict which tokens tend to appear\n", "after an input, they tend to produce better reasoning with that prompt --\n", "an explanation purely in terms of sequence modeling.\n", 
"But it also gives the Transformer license to generate a large number of tokens\n", "that act to store intermediate information,\n", "making for a richer residual stream\n", "for reading and writing." ] }, { "cell_type": "markdown", "metadata": { "id": "RyLRzgG-93yB" }, "source": [ "### Implementation detail: Transformers are position-insensitive by default." ] }, { "cell_type": "markdown", "metadata": { "id": "oR6PnrlA_hJ2" }, "source": [ "In the attention calculation\n", "each token can query each other token,\n", "with no regard for order.\n", "Furthermore, the construction of queries, keys, and values\n", "is based on the content of the embedding vector,\n", "which does not automatically include its position.\n", "\"dog bites man\" and \"man bites dog\" are identical, as in\n", "[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n", "\n", "For most sequences,\n", "this is unacceptable:\n", "absolute and relative position matter\n", "and we cannot use the future to predict the past.\n", "\n", "We need to add two pieces to get a Transformer architecture that's usable for next-token prediction." ] }, { "cell_type": "markdown", "metadata": { "id": "EWHxGJz2-6ZK" }, "source": [ "First, the simpler piece:\n", "\"causal\" attention,\n", "so-named because it ensures that values earlier in the sequence\n", "are not influenced by later values, which would\n", "[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)." 
] }, { "cell_type": "markdown", "metadata": { "id": "0c42xi6URYB4" }, "source": [ "The most common solution is straightforward:\n", "we calculate attention between all tokens,\n", "then throw out non-causal values by \"masking\" them\n", "(this is before applying the softmax,\n", "so masking means adding $-\\infty$).\n", "\n", "This feels wasteful --\n", "why are we calculating values we don't need?\n", "Trying to be smarter would be harder,\n", "and might rely on operations that aren't as optimized as\n", "matrix multiplication and addition.\n", "Furthermore, it's \"only\" twice as many operations,\n", "so it doesn't even show up in $O$-notation.\n", "\n", "A sample attention mask generated by our code base is shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NXaWe6pT-9jV" }, "outputs": [], "source": [ "from text_recognizer.models import transformer_util\n", "\n", "\n", "attention_mask = transformer_util.generate_square_subsequent_mask(100)\n", "\n", "ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n", "plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n", "print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This solves our causality problem,\n", "but we still don't have positional information." 
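] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, here is a minimal sketch of how such a mask enters the attention calculation.\n", "It is an illustration under simple assumptions (uniform logits and a tiny sequence),\n", "not the code path used by our models:\n", "\n", "```python\n", "import torch\n", "\n", "seq_len = 4\n", "scores = torch.zeros(seq_len, seq_len)  # uniform attention logits, for illustration\n", "\n", "# -inf above the diagonal: position i may not attend to positions j > i\n", "mask = torch.full((seq_len, seq_len), float('-inf')).triu(diagonal=1)\n", "weights = torch.softmax(scores + mask, dim=-1)\n", "\n", "print(weights)  # row i spreads its attention over positions 0..i only\n", "```\n", "\n", "Masking only restricts which positions may interact;\n", "it does not tell the model where in the sequence each token sits." 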
] }, { "cell_type": "markdown", "metadata": { "id": "ZamUE4WIoGS2" }, "source": [ "The standard technique\n", "is to add alternating sines and cosines\n", "of increasing frequency to the embeddings\n", "(there are\n", "[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n", "most notably\n", "[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n", "Each position in the sequence is then uniquely identifiable\n", "from the pattern of these values.\n", "\n", "> Furthermore, for the same reason that\n", " [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n", " translations, e.g. relative positions, are fairly easy to express as linear transformations\n", " of sines and cosines." ] }, { "cell_type": "markdown", "metadata": { "id": "IDG2uOsaELU0" }, "source": [ "We superimpose this positional information on our embeddings.\n", "Note that because the model is residual,\n", "this position information will be by default preserved\n", "as it passes through the network,\n", "so it doesn't need to be repeatedly added." 
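] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a point of reference, the sinusoidal scheme from\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n", "can be sketched in a few lines.\n", "This is an illustrative version with hypothetical sizes and variable names;\n", "the implementation in `transformer_util` may differ in its details:\n", "\n", "```python\n", "import math\n", "\n", "import torch\n", "\n", "max_len, d_model = 200, 50  # hypothetical sizes (d_model must be even here)\n", "\n", "position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)\n", "frequency = torch.exp(\n", "    torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)\n", ")  # (d_model / 2,), one frequency per sine/cosine pair\n", "\n", "encoding = torch.zeros(max_len, d_model)\n", "encoding[:, 0::2] = torch.sin(position * frequency)  # even dimensions\n", "encoding[:, 1::2] = torch.cos(position * frequency)  # odd dimensions\n", "\n", "embeddings = torch.randn(max_len, d_model)  # stand-in for token embeddings\n", "embeddings_with_position = embeddings + encoding  # superimpose, as described above\n", "print(embeddings_with_position.shape)\n", "```\n", "\n", "Every row of `encoding` is distinct, which is what makes each position identifiable." 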
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what this positional encoding looks like in our codebase:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Zk62Q-a-1Ax" }, "outputs": [], "source": [ "PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n", "\n", "pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n", "\n", "ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n", "print(pe[:4, :8])" ] }, { "cell_type": "markdown", "metadata": { "id": "ep2ClIWvqDms" }, "source": [ "When we add the positional information to our embeddings,\n", "both the embedding information and the positional information\n", "are approximately preserved,\n", "as can be visually assessed below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PJuFjoCzC0Y4" }, "outputs": [], "source": [ "fake_embeddings = torch.randn_like(pe) * 0.5\n", "\n", "ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n", "\n", "fake_embeddings_with_pe = fake_embeddings + pe\n", "\n", "plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIzBxDkEmH8" }, "source": [ "A [similar technique](https://arxiv.org/abs/2103.06450)\n", "is used to also incorporate positional information into the image embeddings,\n", "which are flattened before being fed to the decoder."
] }, { "cell_type": "markdown", "metadata": { "id": "HC1N85wl8dvn" }, "source": [ "### Learn more about Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "lJwYxkjTk15t" }, "source": [ "We're only able to give a flavor and an intuition for Transformers here.\n", "\n", "To improve your grasp on the nuts and bolts, check out the\n", "[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n", "which is surprisingly approachable,\n", "as far as ML research papers go.\n", "The\n", "[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n", "adds code and commentary to the original paper,\n", "which makes it even more digestible.\n", "For something even friendlier, check out the\n", "[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n", "by Jay Alammar, which has an accompanying\n", "[video](https://youtu.be/-QH8fRhqFHM).\n", "\n", "Anthropic's work on\n", "[Transformer Circuits](https://transformer-circuits.pub/),\n", "summarized above, has some of the best material\n", "for building theoretical understanding\n", "and is still being updated with extensions and applications of the framework.\n", "The\n", "[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n", "are a great aid for checking and building your understanding.\n", "\n", "But they are fairly math-heavy.\n", "If you have more of a software engineering background, see\n", "Transformer Circuits co-author Nelson Elhage's blog post\n", "[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n", "\n", "For a gentler introduction to the intuition for Transformers,\n", "check out Brandon Rohrer's\n", "[Transformers From Scratch](https://e2eml.school/transformers.html)\n", "tutorial." 
] }, { "cell_type": "markdown", "metadata": { "id": "qg7zntJES-aT" }, "source": [ "An aside:\n", "the matrix multiplications inside attention dominate\n", "the big-$O$ runtime of Transformers.\n", "So trying to make the attention mechanism more efficient, e.g. linear time,\n", "has generated a lot of research\n", "(review paper\n", "[here](https://arxiv.org/abs/2009.06732)).\n", "Despite drawing a lot of attention, so to speak,\n", "at the time of writing in mid-2022, these methods\n", "[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n", "so it isn't likely to be worth the effort to spend time learning about them\n", "unless you are a Transformer specialist." ] }, { "cell_type": "markdown", "metadata": { "id": "vCjXysEJ8g9_" }, "source": [ "# Using Transformers to read paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "KsfKWnOvqjva" }, "source": [ "Our simple convolutional model for text recognition from\n", "[Lab 02b](https://fsdl.me/lab02b-colab)\n", "could only handle cleanly-separated characters.\n", "\n", "It worked by sliding a LeNet-style CNN\n", "over the image,\n", "predicting a character for each step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njLdzBqy-I90" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist_lines = text_recognizer.data.EMNISTLines()\n", "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "\n", "# for sliding, see the for loop over range(S)\n", "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "K0N6yDBQq8ns" }, "source": [ "But unfortunately for us, handwritten text\n", "doesn't come in neatly-separated characters\n", "of equal size, so we trained our model on synthetic data\n", "designed to work with that model." 
] }, { "cell_type": "markdown", "metadata": { "id": "hiqUVbj0sxLr" }, "source": [ "Now that we have a better model,\n", "we can work with better data:\n", "paragraphs from the\n", "[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)." ] }, { "cell_type": "markdown", "metadata": { "id": "oizsOAcKs-dD" }, "source": [ "The cell below uses our `LightningDataModule`\n", "to download and preprocess this data,\n", "writing results to disk.\n", "We can then spin up `DataLoader`s to give us batches.\n", "\n", "It can take several minutes to run the first time\n", "on commodity machines,\n", "with most time spent extracting the data.\n", "On subsequent runs,\n", "the time-consuming operations will not be repeated." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uL9LHbjdsUbm" }, "outputs": [], "source": [ "iam_paragraphs = text_recognizer.data.IAMParagraphs()\n", "\n", "iam_paragraphs.prepare_data()\n", "iam_paragraphs.setup()\n", "xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n", "\n", "iam_paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "nBkFN9bbTm_S" }, "source": [ "Now that we've got a batch,\n", "let's take a look at some samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hqaps8yxtBhU" }, "outputs": [], "source": [ "import random\n", "\n", "import numpy as np\n", "import wandb\n", "\n", "\n", "def show(y):\n", " y = y.detach().cpu() # bring back from accelerator if it's being used\n", " return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"<P>\", \"\")\n", "\n", "idx = random.randrange(len(xs)) # randrange excludes len(xs), so idx is always a valid index\n", "\n", "print(show(ys[idx]))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "4dT3UCNzTsoc" }, "source": [ "The `ResnetTransformer` model can run on this data\n", "if passed the `.config`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WXL-vIGRr86D" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())" ] }, { "cell_type": "markdown", "metadata": { "id": "MMxa-oWyT01E" }, "source": [ "Our models are now big enough\n", "that we want to make use of GPU acceleration\n", "as much as we can,\n", "even when working on single inputs,\n", "so let's cast to the GPU if we have one." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YyUM8LgvW0w" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "rnt.to(device); xs = xs.to(device); ys = ys.to(device);" ] }, { "cell_type": "markdown", "metadata": { "id": "Y-E3UdD4zUJi" }, "source": [ "First, let's just pass it through the ResNet encoder."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-LUUtlvaxrvg" }, "outputs": [], "source": [ "resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n", " # resnet is designed for RGB images, so we replicate the input across channels 3 times" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eimgJ5dnywjg" }, "outputs": [], "source": [ "resnet_idx = random.randrange(len(resnet_embedding)) # re-execute to view a different channel\n", "plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n", "plt.axis(\"off\"); plt.colorbar(fraction=0.05);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These embeddings, though generated by random, untrained weights,\n", "are not entirely useless.\n", "\n", "Before neural networks could be effectively\n", "trained end to end,\n", "they were often used with frozen random weights\n", "everywhere except the final layer\n", "(see e.g.\n", "[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n", "[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n", "these methods were still competitive, and\n", "[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n", "provide a\n", "[theoretical basis](https://arxiv.org/abs/2011.14522)\n", "for understanding their performance." ] }, { "cell_type": "markdown", "metadata": { "id": "ye6pW0ETzw2A" }, "source": [ "The final result, though, is repetitive gibberish --\n", "at the bare minimum, we need to train the unembedding/readout layer\n", "in order to get reasonable text." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our architecture includes randomization with dropout,\n", "so repeated runs of the cell below will generate different outcomes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu3Pa7gLsFMo" }, "outputs": [], "source": [ "preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gvCXUbskv6XM" }, "outputs": [], "source": [ "print(show(preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without teacher forcing, runtime is also variable from iteration to iteration --\n", "the model stops when it generates an \"end sequence\" or padding token,\n", "which is not deterministic thanks to the dropout layers.\n", "For similar reasons, runtime is variable across inputs.\n", "\n", "The variable runtime of autoregressive generation\n", "is also not great for scaling.\n", "In a distributed setting, as required for large scale,\n", "forward passes need to be synced across devices,\n", "and if one device is generating a batch of much longer sequences,\n", "it will cause all the others to idle while they wait on it to finish." ] }, { "cell_type": "markdown", "metadata": { "id": "t76MSVRXV0V7" }, "source": [ "Let's turn our model into a `TransformerLitModel`\n", "so we can run with teacher forcing.\n", "\n", "> You may be wondering:\n", " why isn't teacher forcing part of the PyTorch module?\n", " In general, the `LightningModule`\n", " should encapsulate things that are needed in training, validation, and testing\n", " but not during inference.\n", " The teacher forcing trick fits this paradigm,\n", " even though it's so critical to what makes Transformers powerful. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8qrHRKHowdDi" }, "outputs": [], "source": [ "import text_recognizer.lit_models\n", "\n", "lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlNaFqR50Oid" }, "source": [ "Now we can use `.teacher_forward` if we also provide the target `ys`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lpZdqXS5wn0F" }, "outputs": [], "source": [ "forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Zx9SmsN0QLT" }, "source": [ "This may not run faster than the `rnt.forward`,\n", "since generations are always the maximum possible length,\n", "but runtimes and output lengths are deterministic and constant." ] }, { "cell_type": "markdown", "metadata": { "id": "tu-XNYpi0Qvi" }, "source": [ "Forcing doesn't necessarily make our predictions better.\n", "They remain highly repetitive gibberish." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JcEgify9w0sv" }, "outputs": [], "source": [ "forcing_preds = torch.argmax(forcing_outs, dim=0)\n", "\n", "print(show(forcing_preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xn6GGNzc9a3o" }, "source": [ "## Training the `ResNetTransformer`" ] }, { "cell_type": "markdown", "metadata": { "id": "uvZYsuSyWUXe" }, "source": [ "We're finally ready to train this model on full paragraphs of handwritten text!" 
] }, { "cell_type": "markdown", "metadata": { "id": "3cJwC7b720Sd" }, "source": [ "This is a more serious model --\n", "it's the one we use in the\n", "[deployed TextRecognizer application](http://fsdl.me/app).\n", "It's much larger than the models we've seen thus far,\n", "so it can easily outstrip available compute resources,\n", "in particular GPU memory.\n", "\n", "To help, we use\n", "[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n", "which shrinks the size of most of our floats by half,\n", "reducing memory consumption and potentially speeding up computation.\n", "\n", "If your GPU has less than 8GB of available RAM,\n", "you'll see a \"CUDA out of memory\" `RuntimeError`,\n", "which is something of a\n", "[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n", "In this case, you can resolve it by reducing the `--batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w1mXlhfy04Nm" }, "outputs": [], "source": [ "import torch\n", "\n", "gpus = int(torch.cuda.is_available())\n", "\n", "if gpus:\n", " !nvidia-smi\n", "else:\n", " print(\"watch out! working with this model on a typical CPU is not feasible\")" ] }, { "cell_type": "markdown", "metadata": { "id": "os1vW1rPZ1dy" }, "source": [ "Even with an okay GPU, like a\n", "[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n", "a single epoch of training can take over 10 minutes to run.\n", "We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n", "but you can remove those flags to see what full training looks like."
] }, { "cell_type": "markdown", "metadata": { "id": "vnF6dWFn4JlZ" }, "source": [ "It can take a long time (overnight)\n", "to train this model to decent performance on a single GPU,\n", "so we'll focus on other pieces for the exercises.\n", "\n", "> At the time of writing in mid-2022, the cheapest readily available option\n", "for training this model to decent performance on this dataset with this codebase\n", "comes out around $10, using\n", "[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n", "See, for example,\n", "[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n", "and associated experiment.\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HufjdUZN0t4l", "scrolled": false }, "outputs": [], "source": [ "%%time\n", "# above %%magic times the cell, useful as a poor man's profiler\n", "\n", "%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n", " --gpus={gpus} --batch_size 16 --precision 16 \\\n", " --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2" ] }, { "cell_type": "markdown", "metadata": { "id": "L6fQ93ju3Iku" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "udb1Ekjx3L63" }, "source": [ "### 🌟 Try out gradient accumulation and other \"training tricks\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "kpqViB4p3Wfb" }, "source": [ "Larger batches are helpful not only for increasing parallelization\n", "and amortizing fixed costs\n", "but also for getting more reliable gradients.\n", "Larger batches give gradients with less noise\n", "and to a point, less gradient noise means faster convergence.\n", "\n", "But larger batches result in larger tensors,\n", "which take up more GPU memory,\n", "a resource that is tightly constrained\n", "and device-dependent.\n", "\n", "Does that mean we are limited in the quality of our gradients\n", "due to our machine size?\n", "\n", "Not entirely:\n", "look up the `--accumulate_grad_batches`\n", "argument to the `pl.Trainer`.\n", "You should be able to understand why\n", "it makes it possible to compute the same gradients\n", "you would find for a batch of size `k * N`\n", "on a machine that can only run batches up to size `N`.\n", "\n", "Accumulating gradients across batches is among the\n", "[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n", "Try some of them out!\n", "Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment." 
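As a hint for this exercise, here is a toy sketch of the arithmetic behind gradient accumulation. It is plain Python on a made-up one-parameter model with a mean squared error loss, not how Lightning implements `accumulate_grad_batches`, and it assumes the big batch splits into equal-sized chunks. The function names are hypothetical.

```python
def mean_gradient(weight, batch):
    """Gradient of mean squared error for the 1-D model y_hat = weight * x."""
    return sum(2 * (weight * x - y) * x for x, y in batch) / len(batch)


def accumulated_gradient(weight, batch, micro_batch_size):
    """Process `batch` in equal-sized chunks, summing gradients as we go.

    Dividing by the number of chunks at the end makes the result match
    the mean over the full batch.
    """
    chunks = [batch[i:i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)]
    total = 0.0
    for chunk in chunks:  # only one chunk has to be "in memory" at a time
        total += mean_gradient(weight, chunk)
    return total / len(chunks)


data = [(float(x), 3.0 * x + 1.0) for x in range(1, 9)]  # one batch of size k * N = 8
weight = 0.5

full = mean_gradient(weight, data)
accumulated = accumulated_gradient(weight, data, micro_batch_size=2)  # N = 2, k = 4
print(full, accumulated)
```

The two printed numbers agree, because a mean over `k * N` examples is the same as the mean of `k` means over `N` examples each. The memory cost is that of a batch of size `N`; the price is `k` forward and backward passes per optimizer step.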
] }, { "cell_type": "markdown", "metadata": { "id": "b2vtkmX830y3" }, "source": [ "### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n", "\n", "While training this model to actually fit the whole dataset is infeasible\n", "as a short exercise on commodity hardware,\n", "it's practical to train this model to memorize a batch of 16 examples.\n", "\n", "Passing the `--overfit_batches 1` flag limits the number of training batches to 1\n", "and turns off\n", "[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n", "so that in each epoch, the model just sees the same single batch of data over and over again.\n", "\n", "At first, try training the model to a loss of `2.5` --\n", "it should be doable in 100 epochs or less,\n", "which is just a few minutes on a commodity GPU.\n", "\n", "Once you've got that working,\n", "crank up the number of epochs by a factor of 10\n", "and confirm that the loss continues to go down.\n", "\n", "Some tips:\n", "\n", "- Use `--limit_test_batches 0` to turn off testing.\n", "We don't need it because we don't care about generalization\n", "and it's relatively slow because it runs the model autoregressively.\n", "\n", "- Use `--help` and look through the model class args\n", "to find the arguments used to reduce model size.\n", "\n", "- By default, there's lots of regularization to prevent overfitting.\n", "Look through the args for the model class and data class\n", "for regularization knobs to turn off or down."
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab03_transformers.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab03/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab03/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. 
""" from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines from .iam_paragraphs import IAMParagraphs ================================================ FILE: lab03/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS class 
BaseDataModule(pl.LightningDataModule): """Base for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`). """ def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting.
Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab03/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. 
"The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = 
toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") 
shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab03/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab03/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims = 
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines) ================================================ FILE: lab03/text_recognizer/data/iam.py ================================================ """Class for loading the IAM handwritten text dataset, which encompasses both paragraphs and lines, plus utilities.""" from pathlib import Path from typing import Any, cast, Dict, List, Optional import zipfile from boltons.cacheutils import cachedproperty from defusedxml import ElementTree from PIL import Image, ImageOps import toml from text_recognizer import util from text_recognizer.data.base_data_module import _download_raw_dataset, load_and_print_info import text_recognizer.metadata.iam as metadata from text_recognizer.metadata.iam_paragraphs import NEW_LINE_TOKEN METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME EXTRACTED_DATASET_DIRNAME = metadata.EXTRACTED_DATASET_DIRNAME class IAM: """A dataset of images of handwritten text written on a form underneath a typewritten prompt. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database Images are identified by their "form ID". These IDs are used to separate train, validation and test splits, as keys for dictionaries returning label and image crop region data, and more. The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (LWITLRT): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. 
The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. """ def __init__(self): self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self): if self.xml_filenames: return filename = _download_raw_dataset(self.metadata, DL_DATA_DIRNAME) # type: ignore _extract_raw_dataset(filename, DL_DATA_DIRNAME) def load_image(self, id: str) -> Image.Image: """Load and return an image of an entire IAM form. The image is grayscale with white text on black background. This image will have the printed prompt text at the top, above the handwritten text. Images of individual words or lines and of whole paragraphs can be cropped out using the relevant crop region data. """ image = util.read_image_pil(self.form_filenames_by_id[id], grayscale=True) image = ImageOps.invert(image) return image def __repr__(self): """Print info about the dataset.""" info = ["IAM Dataset"] info.append(f"Total Images: {len(self.xml_filenames)}") info.append(f"Total Test Images: {len(self.test_ids)}") info.append(f"Total Paragraphs: {len(self.paragraph_string_by_id)}") num_lines = sum(len(line_regions) for line_regions in self.line_regions_by_id.values()) info.append(f"Total Lines: {num_lines}") return "\n\t".join(info) @cachedproperty def all_ids(self): """A list of all form IDs.""" return sorted([f.stem for f in self.xml_filenames]) @cachedproperty def ids_by_split(self): return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids} @cachedproperty def split_by_id(self): """A dictionary mapping form IDs to their split according to IAM Lines LWITLRT.""" split_by_id = {id_: "train" for id_ in self.train_ids} split_by_id.update({id_: "val" for id_ in self.validation_ids}) split_by_id.update({id_: "test" for id_ in self.test_ids}) return split_by_id @cachedproperty def train_ids(self): """A list of form IDs which are in the IAM Lines LWITLRT training set.""" return list(set(self.all_ids) - (set(self.test_ids) | 
set(self.validation_ids))) @cachedproperty def test_ids(self): """A list of form IDs from the IAM Lines LWITLRT test set.""" return _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/testset.txt") @property def xml_filenames(self) -> List[Path]: """A list of the filenames of all .xml files, which contain label information.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @cachedproperty def validation_ids(self): """A list of form IDs from IAM Lines LWITLRT validation sets 1 and 2.""" val_ids = _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset1.txt") val_ids.extend(_get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset2.txt")) return val_ids @property def form_filenames(self) -> List[Path]: """A list of the filenames of all .jpg files, which contain images of IAM forms.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def xml_filenames_by_id(self): """A dictionary mapping form IDs to their XML label information files.""" return {filename.stem: filename for filename in self.xml_filenames} @property def form_filenames_by_id(self): """A dictionary mapping form IDs to their JPEG images.""" return {filename.stem: filename for filename in self.form_filenames} @cachedproperty def line_strings_by_id(self): """A dict mapping an IAM form id to its list of line texts.""" return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def line_regions_by_id(self): """A dict mapping an IAM form id to its list of line image crop regions.""" return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def paragraph_string_by_id(self): """A dict mapping an IAM form id to its paragraph text.""" return {id: NEW_LINE_TOKEN.join(line_strings) for id, line_strings in self.line_strings_by_id.items()} @cachedproperty def paragraph_region_by_id(self): """A dict 
mapping an IAM form id to its paragraph image crop region.""" return { id: { "x1": min(region["x1"] for region in line_regions), "y1": min(region["y1"] for region in line_regions), "x2": max(region["x2"] for region in line_regions), "y2": max(region["y2"] for region in line_regions), } for id, line_regions in self.line_regions_by_id.items() } def _extract_raw_dataset(filename: Path, dirname: Path) -> None: print("Extracting IAM data") with util.temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zip_file: zip_file.extractall() def _get_ids_from_lwitlrt_split_file(filename: str) -> List[str]: """Get the ids from Large Writer Independent Text Line Recognition Task (LWITLRT) data split file.""" with open(filename, "r") as f: line_ids_str = f.read() line_ids = line_ids_str.split("\n") page_ids = list({"-".join(line_id.split("-")[:2]) for line_id in line_ids if line_id}) return page_ids def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. 
Note that we replace &quot; with ".""" xml_line_elements = _get_line_elements_from_xml_file(filename) return [_get_text_from_xml_element(el) for el in xml_line_elements] def _get_text_from_xml_element(xml_element: Any) -> str: """Extract text from any XML element.""" return xml_element.attrib["text"].replace("&quot;", '"') def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get the line region dict for each line.""" xml_line_elements = _get_line_elements_from_xml_file(filename) line_regions = [ cast(Dict[str, int], _get_region_from_xml_element(xml_elem=el, xml_path="word/cmp")) for el in xml_line_elements ] assert all(region is not None for region in line_regions), "Line regions cannot be None" # next_line_region["y1"] - prev_line_region["y2"] can be negative due to overlapping characters line_gaps_y = [ max(next_line_region["y1"] - prev_line_region["y2"], 0) for next_line_region, prev_line_region in zip(line_regions[1:], line_regions[:-1]) ] post_line_gaps_y = line_gaps_y + [2 * metadata.LINE_REGION_PADDING] pre_line_gaps_y = [2 * metadata.LINE_REGION_PADDING] + line_gaps_y return [ { "x1": region["x1"] - metadata.LINE_REGION_PADDING, "x2": region["x2"] + metadata.LINE_REGION_PADDING, "y1": region["y1"] - min(metadata.LINE_REGION_PADDING, pre_line_gaps_y[i] // 2), "y2": region["y2"] + min(metadata.LINE_REGION_PADDING, post_line_gaps_y[i] // 2), } for i, region in enumerate(line_regions) ] def _get_line_elements_from_xml_file(filename: str) -> List[Any]: """Get all line xml elements from xml file.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec return xml_root_element.findall("handwritten-part/line") def _get_region_from_xml_element(xml_elem: Any, xml_path: str) -> Optional[Dict[str, int]]: """ Get region from input xml element. The region is downsampled because the stored images are also downsampled. 
Parameters ---------- xml_elem xml element can be a line or word element with x, y, width, and height attributes xml_path should be "word/cmp" if xml_elem is a line element, else "cmp" """ unit_elements = xml_elem.findall(xml_path) if not unit_elements: return None return { "x1": min(int(el.attrib["x"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y1": min(int(el.attrib["y"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "x2": max(int(el.attrib["x"]) + int(el.attrib["width"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y2": max(int(el.attrib["y"]) + int(el.attrib["height"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, } if __name__ == "__main__": load_and_print_info(IAM) ================================================ FILE: lab03/text_recognizer/data/iam_paragraphs.py ================================================ """IAM Paragraphs Dataset class.""" import argparse import json from pathlib import Path from typing import Callable, Dict, Optional, Sequence, Tuple import numpy as np from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.paragraph import ParagraphStem IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME class IAMParagraphs(BaseDataModule): """IAM Handwriting database paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true").lower() == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims 
= metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = ParagraphStem() self.trainval_transform = ParagraphStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): return rank_zero_info( "IAMParagraphs.prepare_data: Cropping IAM paragraph regions and saving them along with labels..." ) iam = IAM() iam.prepare_data() properties = {} for split in ["train", "val", "test"]: crops, labels = get_paragraph_crops_and_labels(iam=iam, split=split) save_crops_and_labels(crops=crops, labels=labels, split=split) properties.update( { id_: { "crop_shape": crops[id_].size[::-1], "label_length": len(label), "num_lines": _num_lines(label), } for id_, label in labels.items() } ) with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: def _load_dataset(split: str, transform: Callable) -> BaseDataset: crops, labels = load_processed_crops_and_labels(split) Y = convert_strings_to_labels(strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]) return BaseDataset(crops, Y, transform=transform) rank_zero_info(f"IAMParagraphs.setup({stage}): Loading IAM paragraph regions and lines...") validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) if stage == "fit" or stage is None: self.data_train = _load_dataset(split="train", transform=self.trainval_transform) self.data_val = _load_dataset(split="val", transform=self.transform) if stage == "test" or stage is None: self.data_test = _load_dataset(split="test", transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM 
Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def validate_input_and_output_dimensions( input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] ) -> None: """Validate input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] # Add 2 because of start and end tokens assert output_dims is not None and output_dims[0] >= properties["label_length"]["max"] + 2 def get_paragraph_crops_and_labels( iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR ) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: """Create IAM paragraph crops and labels for a given split, with resizing.""" crops = {} labels = {} for iam_id in iam.ids_by_split[split]: image = iam.load_image(iam_id) para_region = iam.paragraph_region_by_id[iam_id] crops[iam_id] = image.crop([para_region[_] for _ in ["x1", "y1", "x2", "y2"]]) crops[iam_id] = resize_image(crops[iam_id], scale_factor=scale_factor) labels[iam_id] = iam.paragraph_string_by_id[iam_id] assert len(crops) == len(labels) return crops, labels def save_crops_and_labels(crops: Dict[str, Image.Image], labels: 
Dict[str, str], split: str): """Save crops, labels and shapes of crops of a split.""" (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with open(_labels_filename(split), "w") as f: json.dump(labels, f, indent=4) for id_, crop in crops.items(): crop.save(_crop_filename(id_, split)) def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: """Load processed crops and labels for given split.""" with open(_labels_filename(split), "r") as f: labels = json.load(f) sorted_ids = sorted(labels.keys()) ordered_crops = [Image.open(_crop_filename(id_, split)).convert("L") for id_ in sorted_ids] ordered_labels = [labels[id_] for id_ in sorted_ids] assert len(ordered_crops) == len(ordered_labels) return ordered_crops, ordered_labels def get_dataset_properties() -> dict: """Return properties describing the overall dataset.""" with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: properties = json.load(f) def _get_property_values(key: str) -> list: return [_[key] for _ in properties.values()] crop_shapes = np.array(_get_property_values("crop_shape")) aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0] return { "label_length": { "min": min(_get_property_values("label_length")), "max": max(_get_property_values("label_length")), }, "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))}, "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)}, "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, } def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" def _crop_filename(id_: str, split: str) -> Path: """Return filename of processed crop.""" return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" def _num_lines(label: str) -> int: """Return number of lines of text in label.""" return label.count(NEW_LINE_TOKEN) + 1 if __name__ == 
"__main__": load_and_print_info(IAMParagraphs) ================================================ FILE: lab03/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab03/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import string from typing import List, Optional import nltk import numpy as np from text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def 
__init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in string.punctuation}) text = re.sub(" +", " ", text) return text def load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) 
nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab03/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. Parameters ---------- index Returns ------- (datum, target) """ datum, target = self.data[index], self.targets[index] if self.transform is not None: datum = self.transform(datum) if self.target_transform is not None: target = self.target_transform(target) return datum, target def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor: """ Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]: """ Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the other of size (1 - fraction) * size of the base_dataset. """ split_a_size = int(fraction * len(base_dataset)) split_b_size = len(base_dataset) - split_a_size return torch.utils.data.random_split( # type: ignore base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed) ) def resize_image(image: Image.Image, scale_factor: int) -> Image.Image: """Resize image by scale factor.""" if scale_factor == 1: return image return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR) ================================================ FILE: lab03/text_recognizer/lit_models/__init__.py ================================================ from .base import BaseLitModel from .transformer import TransformerLitModel ================================================ FILE: lab03/text_recognizer/lit_models/base.py ================================================ """Basic LightningModules on which other modules can be built.""" import argparse import pytorch_lightning as pl import torch from torchmetrics import Accuracy from .metrics import CharacterErrorRate OPTIMIZER = "Adam" LR = 1e-3 LOSS = "cross_entropy" ONE_CYCLE_TOTAL_STEPS = 100 class BaseLitModel(pl.LightningModule): """ Generic PyTorch-Lightning class that must be initialized with a PyTorch module. 
""" def __init__(self, model, args: argparse.Namespace = None): super().__init__() self.model = model self.args = vars(args) if args is not None else {} self.data_config = self.model.data_config self.mapping = self.data_config["mapping"] self.input_dims = self.data_config["input_dims"] optimizer = self.args.get("optimizer", OPTIMIZER) self.optimizer_class = getattr(torch.optim, optimizer) self.lr = self.args.get("lr", LR) loss = self.args.get("loss", LOSS) if loss not in ("transformer",): self.loss_fn = getattr(torch.nn.functional, loss) self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() @staticmethod def add_to_argparse(parser): parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") parser.add_argument("--lr", type=float, default=LR) parser.add_argument("--one_cycle_max_lr", type=float, default=None) parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional") return parser def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr) if self.one_cycle_max_lr is None: return optimizer scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps ) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} def forward(self, x): return self.model(x) def predict(self, x): logits = self.model(x) return torch.argmax(logits, dim=1) def training_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.train_acc(logits, y) self.log("train/loss", loss) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) outputs = {"loss": loss} return 
outputs def _run_on_batch(self, batch, with_preds=False): x, y = batch logits = self(x) loss = self.loss_fn(logits, y) return x, y, logits, loss def validation_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.val_acc(logits, y) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) outputs = {"loss": loss} return outputs def test_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.test_acc(logits, y) self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) class BaseImageToTextLitModel(BaseLitModel): # pylint: disable=too-many-ancestors """Base class for ImageToText models in PyTorch Lightning.""" def __init__(self, model, args: argparse.Namespace = None): super().__init__(model, args) self.model = model self.args = vars(args) if args is not None else {} self.inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} self.start_index = self.inverse_mapping["<S>"] self.end_index = self.inverse_mapping["<E>"] self.padding_index = self.inverse_mapping["<P>"] self.ignore_tokens = [self.start_index, self.end_index, self.padding_index] self.val_cer = CharacterErrorRate(self.ignore_tokens) self.test_cer = CharacterErrorRate(self.ignore_tokens) ================================================ FILE: lab03/text_recognizer/lit_models/metrics.py ================================================ """Special-purpose metrics for tracking our model performance.""" from typing import Sequence import torch import torchmetrics class CharacterErrorRate(torchmetrics.CharErrorRate): """Character error rate metric, allowing for tokens to be ignored.""" def __init__(self, ignore_tokens: Sequence[int], *args): super().__init__(*args) self.ignore_tokens = set(ignore_tokens) def update(self, preds: torch.Tensor, targets: torch.Tensor): # type: ignore preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()] targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()] super().update(preds_l, targets_l) def test_character_error_rate(): metric = CharacterErrorRate([0, 1]) X = torch.tensor( [ [0, 2, 2, 3, 3, 1], # error will be 0 [0, 2, 1, 1, 1, 1], # error will be .75 [0, 2, 2, 4, 4, 1], # error will be .5 ] ) Y = torch.tensor( [ [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], ] ) metric(X, Y) assert metric.compute() == sum([0, 0.75, 0.5]) / 3 if __name__ == "__main__": test_character_error_rate() ================================================ FILE: lab03/text_recognizer/lit_models/transformer.py ================================================ """An encoder-decoder Transformer model""" from typing import List, Sequence import torch from .base import BaseImageToTextLitModel from .util import replace_after class TransformerLitModel(BaseImageToTextLitModel): """ Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module. 
The module must implement an encode and decode method, and the forward method should be the forward pass during production inference. """ def __init__(self, model, args=None): super().__init__(model, args) self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index) def forward(self, x): return self.model(x) def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x. Parameters ---------- x Batch of images to be encoded. See self.model.encode for shape information. y Batch of ground truth output sequences. Returns ------- torch.Tensor (B, C, Sy) logits """ x = self.model.encode(x) output = self.model.decode(x, y) # (Sy, B, C) return output.permute(1, 2, 0) # (B, C, Sy) def training_step(self, batch, batch_idx): x, y = batch logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("train/loss", loss) outputs = {"loss": loss} return outputs def validation_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.val_cer(preds, y) self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True) return outputs def test_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("test/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.test_cer(preds, y) # use the test metric, so test results do not share state with validation self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True) return outputs def map(self, ks: Sequence[int], ignore: bool = True) -> str: """Maps an iterable of integers to a 
string using the lit model's mapping.""" if ignore: return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens]) else: return "".join([self.mapping[k] for k in ks]) def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]: """Maps a list of lists of integers to a list of strings using the lit model's mapping.""" return [self.map(k, ignore) for k in ks] def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor: """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index. Parameters ---------- logitlikes (B, C, Sy) Tensor with classes as second dimension. The largest value is the one whose index we will return. Logits, logprobs, and probs are all acceptable. replace_after_end Whether to replace values after the first appearance of the end token with the padding token. Returns ------- torch.Tensor (B, Sy) Tensor of integers in [0, C-1] representing predictions. """ raw = torch.argmax(logitlikes, dim=1) # (B, C, Sy) -> (B, Sy) if replace_after_end: return replace_after(raw, self.end_index, self.padding_index) # (B, Sy) else: return raw # (B, Sy) ================================================ FILE: lab03/text_recognizer/lit_models/util.py ================================================ from typing import Union import torch def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: """Return indices of first appearance of element in x, collapsing along dim. Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 Parameters ---------- x One or two-dimensional Tensor to search for element. element Item to search for inside x. dim Dimension of Tensor to collapse over. Returns ------- torch.Tensor Indices where element occurs in x. If element is not found, return length of x along dim. One dimension smaller than x. 
Raises ------ ValueError if x is not a 1 or 2 dimensional Tensor Examples -------- >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3) tensor([2, 1, 3, 0]) >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0) tensor(0) """ if x.dim() > 2 or x.dim() == 0: raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}") matches = x == element first_appearance_mask = (matches.cumsum(dim) == 1) & matches does_match, match_index = first_appearance_mask.max(dim) first_inds = torch.where(does_match, match_index, x.shape[dim]) return first_inds def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor: """Replace all values in each row of 2d Tensor x after the first appearance of element with replace. Parameters ---------- x Two-dimensional Tensor (shape denoted (B, S)) to replace values in. element Item to search for inside x. replace Item that replaces entries that appear after element. Returns ------- outs New Tensor of same shape as x with values after element replaced. 
Examples -------- >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4) tensor([[1, 2, 3], [2, 3, 4], [1, 1, 1], [3, 4, 4]]) """ first_appearances = first_appearance(x, element, dim=1) # (B,) indices = torch.arange(0, x.shape[-1]).type_as(x) # (S,) outs = torch.where( indices[None, :] <= first_appearances[:, None], # if index is at or before first appearance x, # return the value from x replace, # otherwise, return the replacement value ) return outs # (B, S) ================================================ FILE: lab03/text_recognizer/metadata/emnist.py ================================================ from pathlib import Path import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json" NUM_SPECIAL_TOKENS = 4 INPUT_SHAPE = (28, 28) DIMS = (1, *INPUT_SHAPE) # Extra dimension added by ToTensor() OUTPUT_DIMS = (1,) MAPPING = [ "<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] ================================================ FILE: lab03/text_recognizer/metadata/emnist_lines.py ================================================ from pathlib import Path import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json" CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3] DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None) # width variable, depends on maximum sequence length MAPPING = emnist.MAPPING ================================================ FILE: lab03/text_recognizer/metadata/iam.py ================================================ import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam" EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" DOWNSAMPLE_FACTOR = 2 # if images were downsampled, the regions must also be LINE_REGION_PADDING = 8 # add this many pixels around the exact coordinates ================================================ FILE: lab03/text_recognizer/metadata/iam_paragraphs.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN] IMAGE_SCALE_FACTOR = 
2 IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640 IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) MAX_LABEL_LENGTH = 682 DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) ================================================ FILE: lab03/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab03/text_recognizer/metadata/shared.py ================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab03/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple from .resnet_transformer import ResnetTransformer ================================================ FILE: lab03/text_recognizer/models/cnn.py ================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. 
Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab03/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. 
Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab03/text_recognizer/models/mlp.py ================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = 
len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab03/text_recognizer/models/resnet_transformer.py ================================================ """Model combining a ResNet with a Transformer for image-to-sequence tasks.""" import argparse import math from typing import Any, Dict import torch from torch import nn import torchvision from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage TF_DIM = 256 TF_FC_DIM = 1024 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 RESNET_DIM = 512 # hard-coded class ResnetTransformer(nn.Module): """Pass an image through a Resnet and decode the resulting embedding with a Transformer.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) self.mapping = data_config["mapping"] inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # ## Encoder part - should output vector sequence of length self.dim per sample resnet = torchvision.models.resnet18(weights=None) self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) # encoder_projection will output (B, dim, _H, _W) logits self.enc_pos_encoder = PositionalEncodingImage( d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] ) # Max (Ho, Wo) # ## Decoder part self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def forward(self, x: torch.Tensor) -> torch.Tensor: """Autoregressively produce sequences of labels from input images. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- output_tokens (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1] # Set the last output token # Early stopping of prediction loop to speed up prediction if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): break # Set all tokens after end or padding token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") if self.encoder_projection.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.encoder_projection.bias, -bound, bound) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps """ _B, C, _H, _W = x.shape if C == 1: x = x.repeat(1, 3, 1, 1) x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo return x def decode(self, x, y): """Decode a batch of encoded images x with guiding sequences y. During autoregressive inference, the guiding sequence will be previous predictions. During training, the guiding sequence will be the ground truth. Parameters ---------- x (Sx, B, E) images encoded as sequences of embeddings y (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) batch of logit sequences """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.dec_pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output @staticmethod def add_to_argparse(parser): parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser 
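The decode docstring above notes that, during autoregressive inference, the guiding sequence is the model's own previous predictions. A minimal pure-Python sketch of that greedy loop for a single sequence is given here; it is an illustration under stated assumptions, not the repository's batched tensor implementation. The names greedy_decode and toy_model and the token values START, END, PAD are hypothetical, and next_token stands in for calling decode and taking the argmax at the last position.

```python
from typing import Callable, List

# Hypothetical token indices for this sketch only; in the repository the
# real indices are looked up from data_config["mapping"].
START, END, PAD = 1, 2, 3


def greedy_decode(next_token: Callable[[List[int]], int], max_len: int) -> List[int]:
    """Greedy autoregressive decoding for one sequence.

    next_token is an assumed stand-in for "run the decoder on the prefix
    and take the argmax at the final position".
    """
    tokens = [PAD] * max_len
    tokens[0] = START
    for sy in range(1, max_len):
        # Feed everything predicted so far back in as the guiding sequence.
        tokens[sy] = next_token(tokens[:sy])
        if tokens[sy] in (END, PAD):
            break  # stop early: nothing after the end token matters
    # Overwrite everything that follows an END or PAD token with PAD.
    for sy in range(1, max_len):
        if tokens[sy - 1] in (END, PAD):
            tokens[sy] = PAD
    return tokens


def toy_model(prefix: List[int]) -> int:
    """Toy stand-in for a trained model: emits 10, then 11, then END."""
    script = [10, 11, END]
    return script[min(len(prefix) - 1, len(script) - 1)]


if __name__ == "__main__":
    print(greedy_decode(toy_model, max_len=6))
```

The batched version can only stop early once every sequence in the batch has produced an end or padding token, which is why it needs the second pass that pads out the sequences that finished sooner.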
================================================ FILE: lab03/text_recognizer/models/transformer_util.py ================================================ """Position Encoding and other utilities for Transformers.""" import math import torch from torch import Tensor import torch.nn as nn class PositionalEncodingImage(nn.Module): """ Module used to add 2-D positional encodings to the feature-map produced by the encoder. Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. """ def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None: super().__init__() self.d_model = d_model assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w) self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor: pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1 d_model // 2) pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) return pe def forward(self, x: Tensor) -> Tensor: """pytorch.nn.module.forward""" # x.shape = (B, d_model, H, W) assert x.shape[1] == self.pe.shape[0] # type: ignore x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore return x class PositionalEncoding(torch.nn.Module): """Classic Attention-is-all-you-need positional encoding.""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None: super().__init__() self.dropout = torch.nn.Dropout(p=dropout) pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model) 
self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_len: int) -> torch.Tensor: pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(1) return pe def forward(self, x: torch.Tensor) -> torch.Tensor: # x.shape = (S, B, d_model) assert x.shape[2] == self.pe.shape[2] # type: ignore x = x + self.pe[: x.size(0)] # type: ignore return self.dropout(x) def generate_square_subsequent_mask(size: int) -> torch.Tensor: """Generate a triangular (size, size) mask.""" mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask ================================================ FILE: lab03/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible. 
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab03/text_recognizer/stems/paragraph.py ================================================ """IAMParagraphs Stem class.""" import torchvision.transforms as transforms import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.image import ImageStem IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH class ParagraphStem(ImageStem): """A stem for handling images that contain a paragraph of text.""" def __init__( self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None, random_perspective_kwargs=None, gaussian_blur_kwargs=None, sharpness_kwargs=None, ): super().__init__() if not augment: self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)]) else: if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "shear": 6, "scale": (0.95, 1), "interpolation": transforms.InterpolationMode.BILINEAR, } if random_perspective_kwargs is None: random_perspective_kwargs = { "distortion_scale": 0.2, "p": 0.5, "interpolation": transforms.InterpolationMode.BILINEAR, } if gaussian_blur_kwargs is None: gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if sharpness_kwargs is None: sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} # 
IMAGE_SHAPE is (576, 640) self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomCrop( size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant" ), transforms.RandomAffine(**random_affine_kwargs), transforms.RandomPerspective(**random_perspective_kwargs), transforms.GaussianBlur(**gaussian_blur_kwargs), transforms.RandomAdjustSharpness(**sharpness_kwargs), ] ) ================================================ FILE: lab03/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number 
of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab03/training/__init__.py ================================================ ================================================ FILE: lab03/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. 
np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
+ " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. 
To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help ``` """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.loss == "transformer": lit_model_class = lit_models.TransformerLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" if goldstar_metric == "validation/cer": filename_format += "-validation.cer={validation/cer:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") trainer.test(datamodule=data, ckpt_path=best_model_path) else:
trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab03/training/util.py ================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: lab04/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the 
core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For a refresher on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by LeCun and Canziani](https://cds.nyu.edu/deep-learning/)." ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run the full environment setup.\n", "\n", "It should take about three minutes to run."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": "markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the `requests` library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).write_bytes(content) # opens, writes, and closes the file\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing and\n", "matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models."
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch."
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how good our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5 # learning rate hyperparameter\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs): # loop over the data repeatedly\n", " for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n", " start_idx = ii * bs # we are ii batches in, each of size bs\n", " end_idx = start_idx + bs # and we want the next bs entries\n", "\n", " # pull batches from x and from y\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", "\n", " # run model\n", " pred = model(xb)\n", "\n", " # get loss\n", " loss = loss_func(pred, yb)\n", "\n", " # calculate the gradients with a backwards pass\n", " loss.backward()\n", "\n", " # update the parameters\n", " with torch.no_grad(): # we don't want to track gradients through this part!\n", " # SGD learning rule: update with negative gradient scaled by lr\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", "\n", " # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", " # until you say so -- by explicitly \"deleting\" them,\n", " # i.e.
setting the gradients to 0.\n", " weights.grad.zero_()\n", " bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset were, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`."
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__() # the nn.Module.__init__ method does important setup, so this is mandatory\n", " self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", " self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): # finds params automatically\n", " p -= p.grad * lr\n", " model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", " def forward(self, xb):\n", " return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", " opt = configure_optimizer(self)\n", "\n", " for epoch in range(epochs):\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!" 
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", " weights.grad.zero_()\n", " bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move around electrons."
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", " return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", " \n", " def __init__(self, dir, bs=32):\n", " self.dir = dir\n", " self.bs = bs\n", " self.path = self.dir / self.filename\n", "\n", " def prepare_data(self):\n", " if not (self.path).exists():\n", " content = requests.get(self.url + self.filename).content\n", " self.path.open(\"wb\").write(content)\n", "\n", " def setup(self):\n", " with gzip.open(self.path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", " x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", " )\n", " \n", " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", " def train_dataloader(self):\n", " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", " \n", " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", " datamodule.prepare_data()\n", " datamodule.setup()\n", "\n", " val_dataloader = datamodule.val_dataloader()\n", " \n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", " opt = configure_optimizer(self)\n", " train_dataloader = datamodule.train_dataloader()\n", " for epoch in range(epochs):\n", " self.train()\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab04/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", " inputs, targets = batch\n", "\n", " outputs = model(inputs)\n", " loss = some_loss_function(outputs, targets)\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", "\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes." 
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", " def __init__(self):\n", " super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", " # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", " self.model = torch.nn.Linear(in_features=1, out_features=1)\n", " # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", " # optionally, define a forward method\n", " def forward(self, xs):\n", " return self.model(xs) # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try: # try using the LinearRegression LightningModule defined above\n", " logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n", "\n", " model = LinearRegression()\n", "\n", " # we'll explain how the Trainer works in a bit\n", " trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", " trainer.fit(model=model) \n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", " print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n", "\n", "finally: # bring back info-level logging\n", " logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message 
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", " # run the train step\n", " loss = training_step(batch)\n", "\n", " # clear gradients\n", " optimizer.zero_grad()\n", "\n", " # backprop\n", " loss.backward()\n", "\n", " # update parameters\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " xs, ys = batch # unpack the batch\n", " outs = self(xs) # apply the model\n", " loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n", " return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we 
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudocode (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", " def closure(): # wrap the loss calculation\n", " loss = model.training_step(batch, batch_idx, ...)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " return loss\n", "\n", " # optimizer can call the loss calculation as many times as it likes\n", " optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", " optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n", " return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n", 
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", " def __init__(self, N=10_000):\n", " self.N = N\n", " self.xs = torch.randn(size=(N, 1))\n", " self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n", "\n", " def __getitem__(self, idx):\n", " return (self.xs[idx], self.ys[idx])\n", "\n", " def __len__(self):\n", " return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging " ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", " def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", " super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", " # set some constants, like the train/val split\n", " self.size = size\n", " self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", " self.train_indices = list(range(math.floor(self.size * train_frac)))\n", " self.val_indices = list(range(len(self.train_indices), self.size)) # start after the last train index so the splits don't overlap\n", "\n", " # under the hood, we've still got a torch Dataset\n", " self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up in each machine\n", "(for example, downloading data from storage and writing it to the local disk)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance-sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also, implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates best performance practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    pass # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    pass # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", "    def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n", "        super().__init__() # don't forget this!\n", "        self.dataset = None\n", "        self.test_dataset = None # define a test set -- another sample from the same distribution\n", "\n", "    def setup(self, stage=None):\n", "        pass\n", "\n", "    def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", "        pass # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null,
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab04/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?" 
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we apply that kernel repeatedly,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type":
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh"
}, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (to right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represent (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "arguments.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work off of,\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors." 
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
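] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of the shape rule above,\n", "the cell below applies a convolution with `padding=kernel_size // 2`\n", "followed by a 2x2 max-pool to a dummy tensor.\n", "The channel count and image size are illustrative assumptions,\n", "not values read from our `CNN`.\n", "The convolution should leave height and width unchanged,\n", "while the pooling layer cuts the number of entries per channel by a factor of 4,\n", "so quadrupling the number of kernels in the following block\n", "would restore the total tensor size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "kernel_size = 3\n", "x = torch.randn(1, 32, 28, 28) # dummy input, sizes are assumptions: (batch, channels, height, width)\n", "\n", "conv = torch.nn.Conv2d(32, 32, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)\n", "pool = torch.nn.MaxPool2d(2) # 2x2 pooling halves height and width\n", "\n", "conv_out = conv(x)\n", "pool_out = pool(conv_out)\n", "\n", "print(x.shape, conv_out.shape) # height and width are unchanged by the convolution\n", "print(pool_out.shape, conv_out.numel() // pool_out.numel()) # pooling leaves 4x fewer entries"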
] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = max(cnn.parameters(), key=lambda p: p.numel()) # the parameter tensor with the most entries\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like dropout\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by a convolutional layer can be higher\n", "than that done by layers with more parameters --\n", "much higher." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2d layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # a quick and dirty approximation that ignores padding and stride\n", "    num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", "    input_height, input_width = input_size[1], input_size[2]\n", "\n", "    # one kernel application multiplies every kernel weight by one input entry\n", "    multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", "    # number of positions at which the kernel fits inside the unpadded input\n", "    num_applications = (input_height - kernel_height + 1) * (input_width - kernel_width + 1)\n", "    multiplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", "    return multiplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory-constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n", "filter_to_ckpts = \"grep \.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", "    \"model_class\": \"CNN\",\n", "    \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", "    latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", "  --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels." 
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
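] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick way to see the effect described above, outside of our codebase,\n", "is to tally the predictions of a small randomly-initialized network on random inputs.\n", "The sketch below uses a stand-in two-layer MLP and Gaussian noise,\n", "both assumptions chosen for brevity rather than our `LineCNNSimple` and `EMNISTLines`.\n", "The per-class counts are typically far from uniform." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "torch.manual_seed(0) # arbitrary fixed seed so the tally is repeatable\n", "\n", "num_classes = 10 # illustrative choice, not the number of EMNIST classes\n", "net = torch.nn.Sequential(\n", "    torch.nn.Linear(64, 128), torch.nn.ReLU(), torch.nn.Linear(128, num_classes))\n", "\n", "with torch.no_grad(): # no training, we only want the untrained network's outputs\n", "    preds = net(torch.randn(1000, 64)).argmax(dim=1)\n", "\n", "torch.bincount(preds, minlength=num_classes) # number of times each class was predicted"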
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `<P>`.\n", "\n", "A model that always predicts `<P>` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"<P>\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab04/notebooks/lab03_transformers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 03: Transformers and Paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The fundamental reasons why the Transformer is such\n", "a powerful and popular architecture\n", "- Core intuitions for the behavior of Transformer architectures\n", "- How to use a convolutional encoder and a Transformer decoder to recognize\n", "entire paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 3\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Transformers?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal in building a text recognizer is to take a two-dimensional image\n", "and convert it into a one-dimensional sequence of characters\n", "from some alphabet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional neural networks,\n", "discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n", "are great at encoding images,\n", "taking them from their raw pixel values\n", "to a more semantically meaningful numerical representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But how do we go from that to a sequence of letters?\n", "And what's especially tricky:\n", "the number of letters in an image is separable from its size.\n", "A screenshot of this document has a much higher density of letters\n", "than a close-up photograph of a piece of paper.\n", "How do we get a _variable-length_ sequence of letters,\n", "where the length need have nothing to do with the size of the input tensor?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n", "they were\n", "[originally introduced](https://arxiv.org/abs/1706.03762)\n", "for transforming one sequence into another,\n", "as in machine translation.\n", "This makes them a natural fit for processing language.\n", "\n", "But they have also found success in other domains --\n", "at the time of this writing, large transformers\n", "dominate the\n", "[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n", "that has become a de facto standard for comparing models\n", "and are finding\n", "[application in reinforcement learning](https://arxiv.org/abs/2106.01345)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we will use a Transformer as a key component of our final architecture:\n", "we will encode our input images with a CNN\n", "and then read them out into a text sequence with a Transformer.\n", "\n", "Before trying out this new model,\n", "let's first get an understanding of why the Transformer architecture\n", "has become so popular by walking through its history\n", "and then get some intuition for how it works\n", "by looking at some\n", "[recent work](https://transformer-circuits.pub/)\n", "on explaining the behavior of both toy models and state-of-the-art language models." 
] }, { "cell_type": "markdown", "metadata": { "id": "kmKqjbvd-Mj3" }, "source": [ "## Why not convolutions?" ] }, { "cell_type": "markdown", "metadata": { "id": "SRqkUMdM-OxU" }, "source": [ "In the ancient beforetimes (i.e. 2016),\n", "the best models for natural language processing were all\n", "_recurrent_ neural networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional networks were also occasionally used,\n", "but they suffered from a serious issue:\n", "their architectural biases don't fit text.\n", "\n", "First, _translation equivariance_ no longer holds.\n", "The beginning of a piece of text is often quite different from the middle,\n", "so the absolute position matters.\n", "\n", "Second, _locality_ is not as important in language.\n", "The name of a character that hasn't appeared in thousands of pages\n", "can become salient when someone asks, \"Whatever happened to\n", "[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n", "\n", "Consider interpreting a piece of text like the Python code below:\n", "```python\n", "def do(arg1, arg2, arg3):\n", " a = arg1 + arg2\n", " b = arg3[:3]\n", " c = a * b\n", " return c\n", "\n", "print(do(1, 1, \"ayy lmao\"))\n", "```\n", "\n", "After a `(` we expect a `)`,\n", "but possibly very long afterwards,\n", "[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n", "and similarly we expect a `]` at some point after a `[`.\n", "\n", "For translation variance, consider\n", "that we interpret `*` not by\n", "comparing it to its neighbors\n", "but by looking at `a` and `b`.\n", "We mix knowledge learned through experience\n", "with new facts learned while reading --\n", "also known as _in-context learning_.\n", "\n", "In a longer text,\n", "[e.g. 
the one you are reading now](./lab03_transformers.ipynb),\n", "the translation variance of text is clearer.\n", "Every lab notebook begins with the same header,\n", "setting up the environment,\n", "but that header never appears elsewhere in the notebook.\n", "Later positions need to be processed in terms of the previous entries.\n", "\n", "Unlike an image, we cannot simply rotate or translate our \"camera\"\n", "and get a new valid text.\n", "[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n", "that can be read without regard to position." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The field of formal language theory,\n", "which has deep mutual influence with computer science,\n", "gives one way of explaining the issues with convolutional networks:\n", "they can only understand languages with _finite contexts_,\n", "where all the information can be found within a finite window." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The immediate solution, drawing from the connections to computer science, is\n", "[recursion](https://www.google.com/search?q=recursion).\n", "A network whose output on the final entry of the sequence is a recursive function\n", "of all the previous entries can build up knowledge\n", "as it reads the sequence and treat early entries quite differently than it does late ones." 
] }, { "cell_type": "markdown", "metadata": { "id": "aa6cbTlImkEh" }, "source": [ "In pseudo-code, such a _recurrent neural network_ module might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "lKtBoPnglPrW" }, "source": [ "```python\n", "def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n", " next_inputs = input_module(xs[-1])\n", " next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n", " return output_module(next_inputs, next_hiddens)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "IbJPSMnEm516" }, "source": [ "If you've had formal computer science training,\n", "then you may be familiar with the power of recursion,\n", "e.g. the\n", "[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n", "that gave its name to the now much better-known\n", "[startup incubator](https://www.ycombinator.com/).\n", "\n", "The particular form of recursion used by\n", "recurrent neural networks implements a\n", "[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n", "\n", "> If you know a lot of computer science,\n", "you might be concerned by this connection.\n", "What about other\n", "[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n", "Where are the neural network architectures for differentiable\n", "[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n", "Check out Graph Neural Networks,\n", "[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
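,
"\n",
"\n",
"The reduce-like structure can be sketched in a few lines of plain Python.\n",
"This is a minimal illustration, not code from `text_recognizer`:\n",
"the names `rnn_step`, `run_rnn`, and `run_rnn_loop` and the scalar weights\n",
"are assumptions, standing in for the learned weight matrices of a real RNN.\n",
"\n",
"```python\n",
"import math\n",
"from functools import reduce\n",
"\n",
"\n",
"def rnn_step(hidden, x, w_h=0.5, w_x=1.0):\n",
"    # combine the running hidden state with the next input (toy scalar weights)\n",
"    return math.tanh(w_h * hidden + w_x * x)\n",
"\n",
"\n",
"def run_rnn(xs, initial_hidden=0.0):\n",
"    # fold the step function over the sequence, left to right\n",
"    return reduce(rnn_step, xs, initial_hidden)\n",
"\n",
"\n",
"def run_rnn_loop(xs, initial_hidden=0.0):\n",
"    # the same computation written as an explicit loop\n",
"    hidden = initial_hidden\n",
"    for x in xs:\n",
"        hidden = rnn_step(hidden, x)\n",
"    return hidden\n",
"\n",
"\n",
"xs = [1.0, -2.0, 0.5]\n",
"final_hidden = run_rnn(xs)\n",
"```\n",
"\n",
"Each call to `rnn_step` needs the hidden state from the previous call,\n",
"so the steps have to run one after another,\n",
"and reordering the inputs changes the result."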
] }, { "cell_type": "markdown", "metadata": { "id": "63mMTbEBpVuE" }, "source": [ "Recurrent networks are able to achieve\n", "[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n", "\n", "There are many popular recurrent architectures,\n", "from the beefy and classic\n", "[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)\n", "to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n", "([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n", "all of which have roughly similar capabilities but\n", "[some of which are easier to train](https://arxiv.org/abs/1611.09913)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwQHVTIslOku" }, "source": [ "In the same sense that MLPs can model \"any\" feedforward function,\n", "in principle even basic RNNs\n", "[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n", "\n", "In particular they can model any\n", "[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n", "which is a formal way of saying that they can in principle\n", "do anything a computer is capable of doing.\n", "\n", "The question is then..." ] }, { "cell_type": "markdown", "metadata": { "id": "3J8EoGN3pu7P" }, "source": [ "## Why aren't we all using RNNs?"
] }, { "cell_type": "markdown", "metadata": { "id": "TDwNWaevpt_3" }, "source": [ "The guarantees that MLPs can model any function\n", "or that RNNs can model Turing machines\n", "provide decent intuition but are not directly practically useful.\n", "Among other reasons, they don't guarantee learnability --\n", "that starting from random parameters we can find the parameters\n", "that implement a given function.\n", "The\n", "[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n", "than would seem from basic theoretical and empirical analysis.\n", "\n", "One way of understanding capacity to model language is\n", "[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n", "In this model of formal languages,\n", "Turing machines sit at the top\n", "([practically speaking](https://arxiv.org/abs/math/0209332)).\n", "\n", "With better mathematical models,\n", "RNNs and LSTMs can be shown to be\n", "[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n", "with RNNs looking more like\n", "[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n", "and LSTMs coming in\n", "[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n", "\n", "More controversially:\n", "the Chomsky hierarchy is great for understanding syntax and grammar,\n", "which makes it great for building parsers\n", "and working with formal languages,\n", "but the goal in _natural_ language processing is to understand _natural_ language.\n", "Most humans' natural language is far from strictly grammatical,\n", "but that doesn't mean it is nonsense.\n", "\n", "And to really \"understand\" language means\n", "to understand its semantic content, which is fuzzy.\n", "The most important thing for handling the fuzzy semantic content\n", "of language is not whether you can recall\n", "[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n", "but whether you can 
model probabilistic relationships between concepts\n", "in addition to grammar and syntax." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These both leave theoretical room for improvement over current recurrent\n", "language and sequence models.\n", "\n", "But the real cause of the rise of Transformers is that..." ] }, { "cell_type": "markdown", "metadata": { "id": "Dsu1ebvAp-3Z" }, "source": [ "## Transformers are designed to train fast at scale on contemporary hardware." ] }, { "cell_type": "markdown", "metadata": { "id": "c4abU5adsPGs" }, "source": [ "The Transformer architecture has several important features,\n", "discussed below,\n", "but one of the most important reasons why it is successful\n", "is that it can be more easily trained at scale.\n", "\n", "This scalability is the focus of the discussion in the paper\n", "that introduced the architecture,\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", "and\n", "[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n", "\n", "The recursion in RNNs is inherently sequential:\n", "the dependence on the outputs from earlier in the sequence\n", "means computations within an example cannot be parallelized.\n", "\n", "So RNNs must batch across examples to scale,\n", "but as sequence length grows this hits memory bandwidth limits.\n", "Serving up large batches quickly with good randomness guarantees\n", "is also hard to optimize,\n", "especially in distributed settings.\n", "\n", "The Transformer architecture,\n", "on the other hand,\n", "can be readily parallelized within a single example sequence,\n", "in addition to parallelization across batches.\n", "This can lead to massive performance gains for a fixed scale,\n", "which means larger, higher capacity models\n", "can be trained on larger datasets."
] }, { "cell_type": "markdown", "metadata": { "id": "_Mzk2haFC_G1" }, "source": [ "How does the architecture achieve this parallelizability?\n", "\n", "Let's start with the architecture diagram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u59eu4snLQfp" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n", "\n", "display.Image(url=base_url + \"/aiayn-figure-1.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ez-XEQ7M0UlR" }, "source": [ "> To head off a bit of confusion\n", " in case you've worked with Transformer architectures before:\n", " the original \"Transformer\" is an encoder/decoder architecture.\n", " Many LLMs, like GPT models, are decoder only,\n", " because this has turned out to scale well,\n", " and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n", " it's all text anyways.\n", " We, however, will be using them across modalities,\n", " so we need an explicit encoder,\n", " as above. " ] }, { "cell_type": "markdown", "metadata": { "id": "ok4ksBi4vp89" }, "source": [ "First focusing on the encoder (left):\n", "the encoding at a given position is a function of all previous inputs.\n", "But it is not a function of the previous _encodings_:\n", "we produce the encodings \"all at once\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "RPN7C-_OqzHP" }, "source": [ "The decoder (right) does use previous \"outputs\" as its inputs,\n", "but those outputs are not the vectors of layer activations\n", "(aka embeddings)\n", "that are produced by the network.\n", "They are instead the processed outputs,\n", "after a `softmax` and an `argmax`.\n", "\n", "We could obtain these outputs by processing the embeddings,\n", "much like in a recurrent architecture.\n", "In fact, that is one way that Transformers are run.\n", "It's what happens in the `.forward` method\n", "of the model we'll be training for character recognition:\n", "`ResnetTransformer`." ] }, { "cell_type": "markdown", "metadata": { "id": "L5_2WMmtDnJn" }, "source": [ "Let's look at that forward method\n", "and connect it to the diagram." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FR5pk4kEyCGg" }, "outputs": [], "source": [ "from text_recognizer.models import ResnetTransformer\n", "\n", "\n", "ResnetTransformer.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "-J5UFDoPzPbq" }, "source": [ "`.encode` happens first -- that's the left side of diagram.\n", "\n", "The encoder can in principle be anything\n", "that produces a sequence of fixed-length vectors,\n", "but here it's\n", "[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n", "\n", "Then we start iterating over the sequence\n", "in the `for` loop.\n", "\n", "Focus on the first few lines of code.\n", "We apply `.decode` (right side of diagram)\n", "to the outputs so far.\n", "\n", "Once we have a new `output`, we apply `.argmax`\n", "to turn the logits into a concrete prediction of\n", "a particular token.\n", "\n", "This is added as the last output token\n", "and then the loop happens again." 
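,
"\n",
"\n",
"The loop described above can be sketched without any of the real model code.\n",
"Everything in this sketch is a hypothetical stand-in:\n",
"`toy_decode` plays the role of `.decode`,\n",
"the encoder output is just a list of token ids,\n",
"the ids chosen for the start, end, and padding tokens are made up,\n",
"and the logits follow a hard-coded rule rather than coming from a trained network.\n",
"\n",
"```python\n",
"START, END, PAD = 0, 1, 2  # hypothetical special token ids\n",
"VOCAB_SIZE = 6\n",
"\n",
"\n",
"def toy_decode(encoded, tokens_so_far):\n",
"    # stand-in for `.decode`: returns a list of logits for the next token\n",
"    # rule: copy the encoded sequence one element at a time, then emit END\n",
"    position = len(tokens_so_far) - 1  # number of non-START tokens so far\n",
"    target = encoded[position] if position < len(encoded) else END\n",
"    return [1.0 if token == target else 0.0 for token in range(VOCAB_SIZE)]\n",
"\n",
"\n",
"def argmax(logits):\n",
"    return max(range(len(logits)), key=lambda index: logits[index])\n",
"\n",
"\n",
"def greedy_decode(encoded, max_length=8):\n",
"    tokens = [START]\n",
"    for _ in range(max_length - 1):\n",
"        logits = toy_decode(encoded, tokens)  # uses all outputs so far\n",
"        next_token = argmax(logits)  # logits -> one concrete token\n",
"        tokens.append(next_token)  # becomes input to the next iteration\n",
"        if next_token == END:\n",
"            break\n",
"    # pad to a fixed length so every output has the same shape\n",
"    return tokens + [PAD] * (max_length - len(tokens))\n",
"\n",
"\n",
"decoded = greedy_decode(encoded=[3, 4, 5])\n",
"```\n",
"\n",
"Every pass through the loop consumes the tokens produced by the earlier passes,\n",
"so the passes cannot be reordered or run side by side."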
] }, { "cell_type": "markdown", "metadata": { "id": "LTcy8-rV1dHr" }, "source": [ "Run this way, our model looks very much like a recurrent architecture:\n", "we call the model on its own outputs\n", "to generate the next value.\n", "These types of models are also referred to as\n", "[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n", "because we predict (as we do in _regression_)\n", "the next value based on our own (_auto_) output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But Transformers are designed to be _trained_ more scalably than RNNs,\n", "not necessarily to _run inference_ more scalably,\n", "and it's actually not the case that our model's `.forward` is called during training." ] }, { "cell_type": "markdown", "metadata": { "id": "eCxMSAWmEKBt" }, "source": [ "Let's look at what happens during training\n", "by checking the `training_step`\n", "of the `LightningModule`\n", "we use to train our Transformer models,\n", "the `TransformerLitModel`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0o7q8N7P2w4H" }, "outputs": [], "source": [ "from text_recognizer.lit_models import TransformerLitModel\n", "\n", "TransformerLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "1VgNNOjvzC4y" }, "source": [ "Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`." ] }, { "cell_type": "markdown", "metadata": { "id": "tz-6NGPR4dUr" }, "source": [ "Let's look at `.teacher_forward`,\n", "and in particular its type signature:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ILc2oWET4i2Z" }, "outputs": [], "source": [ "TransformerLitModel.teacher_forward??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`." 
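,
"\n",
"\n",
"A toy contrast between the two call patterns, under loudly stated assumptions:\n",
"`next_token_logits` is a made-up stand-in for a decoder\n",
"(it always gives the highest score to the previous token plus one),\n",
"and `logits_with_targets` and `logits_without_targets` are hypothetical names,\n",
"not methods of `TransformerLitModel`.\n",
"\n",
"```python\n",
"START = 0  # hypothetical start token id\n",
"VOCAB_SIZE = 5\n",
"\n",
"\n",
"def next_token_logits(prefix):\n",
"    # made-up stand-in for a decoder: scores the next token given a prefix\n",
"    # rule: predict (last token + 1) modulo the vocabulary size\n",
"    target = (prefix[-1] + 1) % VOCAB_SIZE\n",
"    return [1.0 if token == target else 0.0 for token in range(VOCAB_SIZE)]\n",
"\n",
"\n",
"def logits_with_targets(targets):\n",
"    # decoder inputs are the ground truth shifted right by one position\n",
"    decoder_inputs = [START] + targets[:-1]\n",
"    # each position depends only on known inputs, not on earlier predictions,\n",
"    # so these calls are independent of one another\n",
"    return [\n",
"        next_token_logits(decoder_inputs[: position + 1])\n",
"        for position in range(len(decoder_inputs))\n",
"    ]\n",
"\n",
"\n",
"def logits_without_targets(length):\n",
"    # without targets, each step must wait for the previous prediction\n",
"    tokens, all_logits = [START], []\n",
"    for _ in range(length):\n",
"        logits = next_token_logits(tokens)\n",
"        all_logits.append(logits)\n",
"        tokens.append(logits.index(max(logits)))\n",
"    return all_logits\n",
"\n",
"\n",
"targets = [1, 2, 3, 4]\n",
"with_targets = logits_with_targets(targets)\n",
"without_targets = logits_without_targets(len(targets))\n",
"```\n",
"\n",
"The toy rule is never wrong, so both functions return the same logits here;\n",
"an imperfect model would feed its own mistakes back in when run without targets.\n",
"With targets in hand, the input at every position is known before any prediction is made,\n",
"so no position has to wait on another."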
] }, { "cell_type": "markdown", "metadata": { "id": "lf32lpgrDb__" }, "source": [ "This is known as \"teacher forcing\".\n", "The \"teacher\" signal is \"forcing\"\n", "the model to behave as though\n", "it got the answer right.\n", "\n", "[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n", "It's more effective here\n", "because the right teaching signal\n", "for our network is the target data,\n", "which we have access to during training,\n", "whereas in an RNN the best teaching signal\n", "would be the target embedding vector,\n", "which we do not know.\n", "\n", "During inference, when we don't have access to the ground truth,\n", "we revert to the autoregressive `.forward` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This \"trick\" allows Transformer architectures to readily scale\n", "up models to the parameter counts\n", "[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)." ] }, { "cell_type": "markdown", "metadata": { "id": "BAjqpJm9uUuU" }, "source": [ "## Is there more to Transformers than just a training trick?" ] }, { "cell_type": "markdown", "metadata": { "id": "kWCYXeHv7Qc9" }, "source": [ "[Very](https://arxiv.org/abs/2005.14165),\n", "[very](https://arxiv.org/abs/1909.08053),\n", "[very](https://arxiv.org/abs/2205.01068)\n", "large Transformer models have powered the most recent wave of exciting results in ML, like\n", "[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n", "\n", "They are also the first machine learning models to have come anywhere close to\n", "deserving the term _artificial intelligence_ --\n", "a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is surprising because the models and their training procedure are\n", "(relatively speaking)\n", "pretty _simple_,\n", "even if it doesn't feel that way on first pass." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic Transformer architecture is just a bunch of\n", "dense matrix multiplications and non-linearities --\n", "it's perhaps simpler than a convolutional architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And advances since the introduction of Transformers in 2017\n", "have not in the main been made by\n", "creating more sophisticated model architectures\n", "but by increasing the scale of the base architecture,\n", "or if anything making it simpler, as in\n", "[GPT-type models](https://arxiv.org/abs/2005.14165),\n", "which drop the encoder." ] }, { "cell_type": "markdown", "metadata": { "id": "V1HQS9ey8GMc" }, "source": [ "These models are also trained on very simple tasks:\n", "most LLMs are just trying to predict the next element in the sequence,\n", "given the previous elements --\n", "a task simple enough that Claude Shannon,\n", "father of information theory, was\n", "[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n", "\n", "These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n", "e.g. by scraping the web." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They are also trained in a simple fashion:\n", "first-order stochastic optimizers, like SGD or an\n", "[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n", "intended for the most basic of optimization problems,\n", "that scale more readily than the second-order optimizers\n", "that dominate other areas of optimization." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kz9HPDoy7OAl" }, "source": [ "This is\n", "[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n", "of work in ML:\n", "simple, even seemingly wasteful,\n", "architectures that scale well and are robust\n", "to implementation details\n", "eventually outstrip more clever but\n", "also more finicky approaches that are harder to scale.\n", "This lesson has led some to declare that\n", "[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n", "in machine learning, and perhaps even in artificial intelligence." ] }, { "cell_type": "markdown", "metadata": { "id": "SdN9o2Y771YZ" }, "source": [ "> That is not to say that because the algorithms are relatively simple,\n", " training a model at this scale is _easy_ --\n", " [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n", " [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n", " [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n", " But choosing the simplest algorithm at every step makes solving the scaling problem feasible." ] }, { "cell_type": "markdown", "metadata": { "id": "baVGf6gKFOvs" }, "source": [ "The importance of scale is the key lesson from the Transformer architecture,\n", "far more than any theoretical considerations\n", "or any of the implementation details.\n", "\n", "That said, these large Transformer models are capable of\n", "impressive behaviors and understanding how they achieve them\n", "is of intellectual interest.\n", "Furthermore, like any architecture,\n", "there are common failure modes,\n", "of the model and of the modelers who use them,\n", "that need to be taken into account." 
] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "Below, we'll cover two key intuitions about Transformers:\n", "Transformers are _residual_, like ResNets,\n", "and they compose _low rank_ sequence transformations.\n", "Together, this means they act somewhat like a computer,\n", "reading from and writing to a \"tape\" or memory\n", "with a sequence of simple instructions." ] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "We'll also cover a surprising implementation detail:\n", "despite being commonly used for sequence modeling,\n", "by default the architecture is _position insensitive_." ] }, { "cell_type": "markdown", "metadata": { "id": "uni0VTCr9lev" }, "source": [ "### Intuition #1: Transformers are highly residual." ] }, { "cell_type": "markdown", "metadata": { "id": "0MoBt-JLJz-d" }, "source": [ "> The discussion of these intuitions summarizes the discussion in\n", "[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n", "from\n", "[Anthropic](https://www.anthropic.com/),\n", "an AI safety and research company.\n", "The figures below are from that blog post.\n", "It is the spiritual successor to the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "covered in\n", "[Lab 02b](https://fsdl.me/lab02b-colab).\n", "If you want to truly understand Transformers,\n", "we highly recommend you check it out,\n", "including the\n", "[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
] }, { "cell_type": "markdown", "metadata": { "id": "UUbNVvM5Ferm" }, "source": [ "It's easy to see that ResNets are residual --\n", "it's in the name, after all.\n", "\n", "But Transformers are,\n", "in some sense,\n", "even more closely tied to residual computation\n", "than are ResNets:\n", "ResNets and related architectures include downsampling,\n", "so there is not a direct path from inputs to outputs.\n", "\n", "In Transformers, the exact same shape is maintained\n", "from the moment tokens are embedded,\n", "through dozens or hundreds of intermediate layers,\n", "and until they are \"unembedded\" into class logits.\n", "The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n", "\n", "The residual stream is easy to see with a change of perspective.\n", "Instead of the usual architecture diagram above,\n", "which emphasizes the layers acting on the tensors,\n", "consider this alternative view,\n", "which emphasizes the tensors as they pass through the layers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HRMlVguKKW6y" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-residual-view.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9K3N7ilVkB3" }, "source": [ "For definitions of variables and terms, see the\n", "[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)." ] }, { "cell_type": "markdown", "metadata": { "id": "arvciE-kKd_L" }, "source": [ "Note that this is a _decoder-only_ Transformer architecture --\n", "so it should be compared with the right-hand side of the original architecture diagram above."
] }, { "cell_type": "markdown", "metadata": { "id": "wvrRMd_RKp_G" }, "source": [ "Notice that outputs of the attention blocks\n", "and of the MLP layers are\n", "added to their inputs, as in a ResNet.\n", "These operations are represented as \"Add & Norm\" layers in the classical diagram;\n", "normalization is ignored here for simplicity." ] }, { "cell_type": "markdown", "metadata": { "id": "o8n_iT-FFAbK" }, "source": [ "This total commitment to residual operations\n", "means the size of the embeddings\n", "(referred to as the \"model dimension\" or the \"embedding dimension\",\n", "here and below `d_model`)\n", "stays the same throughout the entire network.\n", "\n", "That means, for example,\n", "that the output of each layer can be used as input to the \"unembedding\" layer\n", "that produces logits.\n", "We can read out the computations of intermediate layers\n", "just by passing them through the unembedding layer\n", "and examining the logit tensor.\n", "See\n", "[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n", "for detailed experiments and interactive notebooks.\n", "\n", "In short, we observe a sort of \"progressive refinement\"\n", "of the next-token prediction\n", "as the embeddings proceed, depthwise, through the network." ] }, { "cell_type": "markdown", "metadata": { "id": "Ovh_3YgY9z2h" }, "source": [ "### Intuition #2: Transformer heads learn low rank transformations."
] }, { "cell_type": "markdown", "metadata": { "id": "XpNmozlnOdPC" }, "source": [ "In the original paper and in\n", "most presentations of Transformers,\n", "the attention layer is written like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PA7me8gNP5LE" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In pseudo-typed PyTorch (based loosely on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n", "that looks like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Oeict_6wGJgD" }, "source": [ "```python\n", "def classic_attention(\n", " Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " K: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " return torch.softmax(Q @ K.T) @ V\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "8pewU90DSuOR" }, "source": [ "This is effectively exactly\n", "how it is written\n", "in PyTorch,\n", "apart from implementation details\n", "(look for `bmm` for the matrix multiplications and a `softmax` call):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrgTpKFvOhwc" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "F._scaled_dot_product_attention??" 
] }, { "cell_type": "markdown", "metadata": { "id": "ebDXZ0tlSe7g" }, "source": [ "But the best way to write an operation so that a computer can execute it quickly\n", "is not necessarily the best way to write it so that a human can understand it --\n", "otherwise we'd all be coding in assembly.\n", "\n", "And this is a strange way to write it --\n", "you'll notice that what we normally think of\n", "as the \"inputs\" to the layer are not shown.\n", "\n", "We can instead write out the attention layer\n", "as a function of the inputs $x$.\n", "We write it for a single \"attention head\".\n", "Each attention layer includes a number of heads\n", "that read and write from the residual stream\n", "simultaneously and independently.\n", "We also add the output layer weights $W_O$\n", "and we get:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LuFNR67tQpsf" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(\\underbrace{x^TW_Q^T}_Q \\underbrace{W_Kx}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SVnBjjfOLwxP" }, "source": [ "or, in pseudo-typed PyTorch:" ] }, { "cell_type": "markdown", "metadata": { "id": "LmpOm-HfGaNz" }, "source": [ "```python\n", "def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n", " key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n", " key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n", " # maps queries of residual stream to keys from residual stream, independent of position\n", "\n", " value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n", " output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n", " value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n", " # transformation applied to each token, regardless of position\n", "\n", " attention_logits = 
x @ key_query_circuit @ x.T\n", " attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits)\n", " # maps positions to positions, often very sparse\n", "\n", " value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n", "\n", " return attention_map @ value_output # transformed tokens filtered by attention map\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "dC0eqxZ6UAGT" }, "source": [ "Consider the `key_query_circuit`\n", "and `value_output_circuit`\n", "matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n", "\n", "The key/query dimension, `d_head`,\n", "is small relative to the model's dimension, `d_model`,\n", "so $W_{QK}$ and $W_{OV}$ are very low rank,\n", "[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n", "that they factorize into two matrices,\n", "one with a smaller number of rows\n", "and another with a smaller number of columns.\n", "That number is called the _rank_.\n", "\n", "When computing, these matrices are better represented via their components,\n", "rather than computed directly,\n", "which leads to the normal implementation of attention.\n", "\n", "In a large language model,\n", "the ratio of residual stream dimension, `d_model`, to\n", "the dimension of a single head, `d_head`, is huge, often 100:1.\n", "That means each query, key, and value computed at a position\n", "is a fairly simple, low-dimensional feature of the residual stream at that position.\n", "\n", "For visual intuition,\n", "we compare a matrix whose rank is one hundredth of full rank\n", "with a full rank matrix of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LUbojJMiW2C" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "\n", "low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n", "full_rank = torch.randn(100, 100)\n", "plt.figure(); 
plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n", "plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");" ] }, { "cell_type": "markdown", "metadata": { "id": "lqBst92-OVka" }, "source": [ "The pattern in the first matrix is very simple,\n", "relative to the pattern in the second matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "SkCGrs9EiVh4" }, "source": [ "Another feature of low rank transformations is\n", "that they have a large nullspace or kernel --\n", "these are directions we can move the input without changing the output.\n", "\n", "That means that many changes to the residual stream won't affect the behavior of this head at all." ] }, { "cell_type": "markdown", "metadata": { "id": "UVz2dQgzhD4p" }, "source": [ "### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)." ] }, { "cell_type": "markdown", "metadata": { "id": "hVlzwR03m8mC" }, "source": [ "The combination of residuality\n", "(changes are added to the current input)\n", "and low rank\n", "(only a small subspace is changed by each head)\n", "drastically changes the intuition about Transformers." ] }, { "cell_type": "markdown", "metadata": { "id": "qqjZI2jKe6HH" }, "source": [ "Rather than being an \"embedding of a token in its context\",\n", "the residual stream becomes something more like a memory or a scratchpad:\n", "one layer reads a small bit of information from the stream\n", "and writes a small bit of information back to it." 
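, "\n", "\n", "As a hedged sketch in the notation used above\n", "(the layer index $\\ell$ and the head set $H_\\ell$ are labels introduced here for illustration),\n", "one attention layer updates the residual stream additively:\n", "\n", "$$x_{\\ell+1} = x_\\ell + \\sum_{h \\in H_\\ell} A^h \\, x_\\ell \\, W_{OV}^{h\\,T}$$\n", "\n", "where $A^h$ is the `d_sequence` by `d_sequence` attention map of head $h$.\n", "Because $W_{OV}^h$ factors through `d_head` dimensions,\n", "\n", "$$\\text{rank}(W_{OV}^h) \\leq d_{head}, \\qquad \\dim \\ker(W_{OV}^h) \\geq d_{model} - d_{head}$$\n", "\n", "so with illustrative sizes $d_{model} = 1024$ and $d_{head} = 64$,\n", "each head writes into a subspace of at most 64 dimensions\n", "and ignores at least 960 dimensions of its input."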
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YIBkxlqepjc" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-layer-residual.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtsKhkLfk00l" }, "source": [ "The residual stream works like a memory because it is roomy enough\n", "that these actions need not interfere:\n", "the subspaces targeted by reads and writes are small relative to the ambient space,\n", "so they can largely avoid overlapping.\n", "\n", "Additionally, the dimension of each head is still in the 100s in large models,\n", "and\n", "[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n", "in them, so the number of effective degrees of freedom is\n", "actually larger than the dimension.\n", "This phenomenon allows high-dimensional tensors to serve as\n", "[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n", "There are\n", "[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n", "\n", "Together, this means an early layer can write information to the stream\n", "that can be used by later layers -- by many of them at once, possibly much later.\n", "Later layers can learn to edit this information,\n", "e.g. deleting it,\n", "if doing so reduces the loss,\n", "but by default the information is preserved." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EragIygzJg86" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-stream-read-write.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "oKIaUZjwkpW7" }, "source": [ "Lastly, the softmax in the attention has a sparsifying effect,\n", "and so many attention heads are reading from\n", "just one token and writing to just one other token."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dN6VcJqIMKnB" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-token-to-token.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Repeatedly reading information from an external memory\n", "and using it to decide which operation to perform\n", "and where to write the results\n", "is at the core of the\n", "[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n", "For a concrete example, the\n", "[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n", "includes a dissection of a form of \"pointer arithmetic\"\n", "that appears in some models." ] }, { "cell_type": "markdown", "metadata": { "id": "0kLFh7Mvnolr" }, "source": [ "This point of view seems\n", "very promising for explaining numerous\n", "otherwise perhaps counterintuitive features of Transformer models.\n", "\n", "- This framework predicts that Transformers will readily copy-and-paste information,\n", "which might explain phenomena like\n", "[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n", "\n", "- It also readily explains\n", "[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n", "an important component of why Transformers perform well on medium-length texts\n", "and in few-shot learning.\n", "\n", "- Transformers also perform better on reasoning tasks when the text\n", "[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n", "is added to their input prompt.\n", "This is partly due to the fact that that prompt is associated,\n", "in the dataset, with clearer reasoning,\n", "and since the models are trained to predict which tokens tend to appear\n", "after an input, they tend to produce better reasoning with that prompt --\n", "an explanation purely in terms of sequence modeling.\n", 
"But it also gives the Transformer license to generate a large number of tokens\n", "that act to store intermediate information,\n", "making for a richer residual stream\n", "for reading and writing." ] }, { "cell_type": "markdown", "metadata": { "id": "RyLRzgG-93yB" }, "source": [ "### Implementation detail: Transformers are position-insensitive by default." ] }, { "cell_type": "markdown", "metadata": { "id": "oR6PnrlA_hJ2" }, "source": [ "In the attention calculation\n", "each token can query each other token,\n", "with no regard for order.\n", "Furthermore, the construction of queries, keys, and values\n", "is based on the content of the embedding vector,\n", "which does not automatically include its position.\n", "\"dog bites man\" and \"man bites dog\" are identical, as in\n", "[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n", "\n", "For most sequences,\n", "this is unacceptable:\n", "absolute and relative position matter\n", "and we cannot use the future to predict the past.\n", "\n", "We need to add two pieces to get a Transformer architecture that's usable for next-token prediction." ] }, { "cell_type": "markdown", "metadata": { "id": "EWHxGJz2-6ZK" }, "source": [ "First, the simpler piece:\n", "\"causal\" attention,\n", "so-named because it ensures that values earlier in the sequence\n", "are not influenced by later values, which would\n", "[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)." 
] }, { "cell_type": "markdown", "metadata": { "id": "0c42xi6URYB4" }, "source": [ "The most common solution is straightforward:\n", "we calculate attention between all tokens,\n", "then throw out non-causal values by \"masking\" them\n", "(this is before applying the softmax,\n", "so masking means adding $-\\infty$).\n", "\n", "This feels wasteful --\n", "why are we calculating values we don't need?\n", "Trying to be smarter would be harder,\n", "and might rely on operations that aren't as optimized as\n", "matrix multiplication and addition.\n", "Furthermore, it's \"only\" twice as many operations,\n", "so it doesn't even show up in $O$-notation.\n", "\n", "A sample attention mask generated by our code base is shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NXaWe6pT-9jV" }, "outputs": [], "source": [ "from text_recognizer.models import transformer_util\n", "\n", "\n", "attention_mask = transformer_util.generate_square_subsequent_mask(100)\n", "\n", "ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n", "plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n", "print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This solves our causality problem,\n", "but we still don't have positional information." 
] }, { "cell_type": "markdown", "metadata": { "id": "ZamUE4WIoGS2" }, "source": [ "The standard technique\n", "is to add alternating sines and cosines\n", "of increasing frequency to the embeddings\n", "(there are\n", "[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n", "most notably\n", "[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n", "Each position in the sequence is then uniquely identifiable\n", "from the pattern of these values.\n", "\n", "> Furthermore, for the same reason that\n", " [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n", " translations, e.g. relative positions, are fairly easy to express as linear transformations\n", " of sines and cosines." ] }, { "cell_type": "markdown", "metadata": { "id": "IDG2uOsaELU0" }, "source": [ "We superimpose this positional information on our embeddings.\n", "Note that because the model is residual,\n", "this position information will be by default preserved\n", "as it passes through the network,\n", "so it doesn't need to be repeatedly added."
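, "\n", "\n", "For reference, the sinusoidal encoding from the original Transformer paper\n", "assigns to position $pos$ and dimension pair index $i$ the values\n", "\n", "$$PE_{(pos, 2i)} = \\sin\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right), \\qquad PE_{(pos, 2i+1)} = \\cos\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right)$$\n", "\n", "Writing $\\omega_i = 10000^{-2i/d_{model}}$ (a shorthand introduced here),\n", "a shift by $k$ positions acts on each sine/cosine pair as a rotation that does not depend on $pos$:\n", "\n", "$$\\begin{pmatrix} \\sin(\\omega_i (pos + k)) \\\\ \\cos(\\omega_i (pos + k)) \\end{pmatrix} = \\begin{pmatrix} \\cos(\\omega_i k) & \\sin(\\omega_i k) \\\\ -\\sin(\\omega_i k) & \\cos(\\omega_i k) \\end{pmatrix} \\begin{pmatrix} \\sin(\\omega_i \\, pos) \\\\ \\cos(\\omega_i \\, pos) \\end{pmatrix}$$\n", "\n", "This is only a sketch of why relative positions are linear in this encoding;\n", "whether the `PositionalEncoding` class in our codebase uses exactly these constants is an assumption."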
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what this positional encoding looks like in our codebase:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Zk62Q-a-1Ax" }, "outputs": [], "source": [ "PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n", "\n", "pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n", "\n", "ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n", "print(pe[:4, :8])" ] }, { "cell_type": "markdown", "metadata": { "id": "ep2ClIWvqDms" }, "source": [ "When we add the positional information to our embeddings,\n", "both the embedding information and the positional information\n", "are approximately preserved,\n", "as can be visually assessed below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PJuFjoCzC0Y4" }, "outputs": [], "source": [ "fake_embeddings = torch.randn_like(pe) * 0.5\n", "\n", "ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n", "\n", "fake_embeddings_with_pe = fake_embeddings + pe\n", "\n", "plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIzBxDkEmH8" }, "source": [ "A [similar technique](https://arxiv.org/abs/2103.06450)\n", "is used to also incorporate positional information into the image embeddings,\n", "which are flattened before being fed to the decoder."
] }, { "cell_type": "markdown", "metadata": { "id": "HC1N85wl8dvn" }, "source": [ "### Learn more about Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "lJwYxkjTk15t" }, "source": [ "We're only able to give a flavor and an intuition for Transformers here.\n", "\n", "To improve your grasp on the nuts and bolts, check out the\n", "[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n", "which is surprisingly approachable,\n", "as far as ML research papers go.\n", "The\n", "[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n", "adds code and commentary to the original paper,\n", "which makes it even more digestible.\n", "For something even friendlier, check out the\n", "[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n", "by Jay Alammar, which has an accompanying\n", "[video](https://youtu.be/-QH8fRhqFHM).\n", "\n", "Anthropic's work on\n", "[Transformer Circuits](https://transformer-circuits.pub/),\n", "summarized above, has some of the best material\n", "for building theoretical understanding\n", "and is still being updated with extensions and applications of the framework.\n", "The\n", "[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n", "are a great aid for checking and building your understanding.\n", "\n", "But they are fairly math-heavy.\n", "If you have more of a software engineering background, see\n", "Transformer Circuits co-author Nelson Elhage's blog post\n", "[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n", "\n", "For a gentler introduction to the intuition for Transformers,\n", "check out Brandon Rohrer's\n", "[Transformers From Scratch](https://e2eml.school/transformers.html)\n", "tutorial." 
] }, { "cell_type": "markdown", "metadata": { "id": "qg7zntJES-aT" }, "source": [ "An aside:\n", "the matrix multiplications inside attention dominate\n", "the big-$O$ runtime of Transformers.\n", "So trying to make the attention mechanism more efficient, e.g. linear time,\n", "has generated a lot of research\n", "(review paper\n", "[here](https://arxiv.org/abs/2009.06732)).\n", "Despite drawing a lot of attention, so to speak,\n", "at the time of writing in mid-2022, these methods\n", "[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n", "so it isn't likely to be worth the effort to spend time learning about them\n", "unless you are a Transformer specialist." ] }, { "cell_type": "markdown", "metadata": { "id": "vCjXysEJ8g9_" }, "source": [ "# Using Transformers to read paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "KsfKWnOvqjva" }, "source": [ "Our simple convolutional model for text recognition from\n", "[Lab 02b](https://fsdl.me/lab02b-colab)\n", "could only handle cleanly-separated characters.\n", "\n", "It worked by sliding a LeNet-style CNN\n", "over the image,\n", "predicting a character for each step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njLdzBqy-I90" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist_lines = text_recognizer.data.EMNISTLines()\n", "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "\n", "# for sliding, see the for loop over range(S)\n", "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "K0N6yDBQq8ns" }, "source": [ "But unfortunately for us, handwritten text\n", "doesn't come in neatly-separated characters\n", "of equal size, so we trained our model on synthetic data\n", "designed to work with that model." 
] }, { "cell_type": "markdown", "metadata": { "id": "hiqUVbj0sxLr" }, "source": [ "Now that we have a better model,\n", "we can work with better data:\n", "paragraphs from the\n", "[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)." ] }, { "cell_type": "markdown", "metadata": { "id": "oizsOAcKs-dD" }, "source": [ "The cell uses our `LightningDataModule`\n", "to download and preprocess this data,\n", "writing results to disk.\n", "We can then spin up `DataLoader`s to give us batches.\n", "\n", "It can take several minutes to run the first time\n", "on commodity machines,\n", "with most time spent extracting the data.\n", "On subsequent runs,\n", "the time-consuming operations will not be repeated." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uL9LHbjdsUbm" }, "outputs": [], "source": [ "iam_paragraphs = text_recognizer.data.IAMParagraphs()\n", "\n", "iam_paragraphs.prepare_data()\n", "iam_paragraphs.setup()\n", "xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n", "\n", "iam_paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "nBkFN9bbTm_S" }, "source": [ "Now that we've got a batch,\n", "let's take a look at some samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hqaps8yxtBhU" }, "outputs": [], "source": [ "import random\n", "\n", "import numpy as np\n", "import wandb\n", "\n", "\n", "def show(y):\n", "    y = y.detach().cpu()  # bring back from accelerator if it's being used\n", "    return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"<P>\", \"\")\n", "\n", "idx = random.randint(0, len(xs) - 1)  # randint includes both endpoints\n", "\n", "print(show(ys[idx]))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "4dT3UCNzTsoc" }, "source": [ "The `ResnetTransformer` model can run on this data\n", "if passed the `.config`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WXL-vIGRr86D" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())" ] }, { "cell_type": "markdown", "metadata": { "id": "MMxa-oWyT01E" }, "source": [ "Our models are now big enough\n", "that we want to make use of GPU acceleration\n", "as much as we can,\n", "even when working on single inputs,\n", "so let's cast to the GPU if we have one." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YyUM8LgvW0w" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "rnt.to(device); xs = xs.to(device); ys = ys.to(device);" ] }, { "cell_type": "markdown", "metadata": { "id": "Y-E3UdD4zUJi" }, "source": [ "First, let's just pass it through the ResNet encoder."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-LUUtlvaxrvg" }, "outputs": [], "source": [ "resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n", " # resnet is designed for RGB images, so we replicate the input across channels 3 times" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eimgJ5dnywjg" }, "outputs": [], "source": [ "resnet_idx = random.randint(0, len(resnet_embedding) - 1) # re-execute to view a different channel\n", "plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n", "plt.axis(\"off\"); plt.colorbar(fraction=0.05);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These embeddings, though generated by random, untrained weights,\n", "are not entirely useless.\n", "\n", "Before neural networks could be effectively\n", "trained end to end,\n", "they were often used with frozen random weights\n", "everywhere except the final layer\n", "(see e.g.\n", "[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n", "[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n", "these methods were still competitive, and\n", "[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n", "provide a\n", "[theoretical basis](https://arxiv.org/abs/2011.14522)\n", "for understanding their performance." ] }, { "cell_type": "markdown", "metadata": { "id": "ye6pW0ETzw2A" }, "source": [ "The final result, though, is repetitive gibberish --\n", "at the bare minimum, we need to train the unembedding/readout layer\n", "in order to get reasonable text." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our architecture includes randomization with dropout,\n", "so repeated runs of the cell below will generate different outcomes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu3Pa7gLsFMo" }, "outputs": [], "source": [ "preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gvCXUbskv6XM" }, "outputs": [], "source": [ "print(show(preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without teacher forcing, runtime is also variable from iteration to iteration --\n", "the model stops when it generates an \"end sequence\" or padding token,\n", "which is not deterministic thanks to the dropout layers.\n", "For similar reasons, runtime is variable across inputs.\n", "\n", "The variable runtime of autoregressive generation\n", "is also not great for scaling.\n", "In a distributed setting, as required for large scale,\n", "forward passes need to be synced across devices,\n", "and if one device is generating a batch of much longer sequences,\n", "it will cause all the others to idle while they wait on it to finish." ] }, { "cell_type": "markdown", "metadata": { "id": "t76MSVRXV0V7" }, "source": [ "Let's turn our model into a `TransformerLitModel`\n", "so we can run with teacher forcing.\n", "\n", "> You may be wondering:\n", " why isn't teacher forcing part of the PyTorch module?\n", " In general, the `LightningModule`\n", " should encapsulate things that are needed in training, validation, and testing\n", " but not during inference.\n", " The teacher forcing trick fits this paradigm,\n", " even though it's so critical to what makes Transformers powerful. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8qrHRKHowdDi" }, "outputs": [], "source": [ "import text_recognizer.lit_models\n", "\n", "lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlNaFqR50Oid" }, "source": [ "Now we can use `.teacher_forward` if we also provide the target `ys`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lpZdqXS5wn0F" }, "outputs": [], "source": [ "forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Zx9SmsN0QLT" }, "source": [ "This may not run faster than the `rnt.forward`,\n", "since generations are always the maximum possible length,\n", "but runtimes and output lengths are deterministic and constant." ] }, { "cell_type": "markdown", "metadata": { "id": "tu-XNYpi0Qvi" }, "source": [ "Forcing doesn't necessarily make our predictions better.\n", "They remain highly repetitive gibberish." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JcEgify9w0sv" }, "outputs": [], "source": [ "forcing_preds = torch.argmax(forcing_outs, dim=0)\n", "\n", "print(show(forcing_preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xn6GGNzc9a3o" }, "source": [ "## Training the `ResNetTransformer`" ] }, { "cell_type": "markdown", "metadata": { "id": "uvZYsuSyWUXe" }, "source": [ "We're finally ready to train this model on full paragraphs of handwritten text!" 
] }, { "cell_type": "markdown", "metadata": { "id": "3cJwC7b720Sd" }, "source": [ "This is a more serious model --\n", "it's the one we use in the\n", "[deployed TextRecognizer application](http://fsdl.me/app).\n", "It's much larger than the models we've seen thus far,\n", "so it can easily outstrip available compute resources,\n", "in particular GPU memory.\n", "\n", "To help, we use\n", "[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n", "which shrinks the size of most of our floats by half,\n", "reducing memory consumption and potentially speeding up computation.\n", "\n", "If your GPU has less than 8GB of available RAM,\n", "you'll see a \"CUDA out of memory\" `RuntimeError`,\n", "which is something of a\n", "[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n", "In this case, you can resolve it by reducing the `--batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w1mXlhfy04Nm" }, "outputs": [], "source": [ "import torch\n", "\n", "gpus = int(torch.cuda.is_available())\n", "\n", "if gpus:\n", "    !nvidia-smi\n", "else:\n", "    print(\"watch out! working with this model on a typical CPU is not feasible\")" ] }, { "cell_type": "markdown", "metadata": { "id": "os1vW1rPZ1dy" }, "source": [ "Even with an okay GPU, like a\n", "[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n", "a single epoch of training can take over 10 minutes to run.\n", "We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n", "but you can remove those flags to see what full training looks like."
] }, { "cell_type": "markdown", "metadata": { "id": "vnF6dWFn4JlZ" }, "source": [ "It can take a long time (overnight)\n", "to train this model to decent performance on a single GPU,\n", "so we'll focus on other pieces for the exercises.\n", "\n", "> At the time of writing in mid-2022, the cheapest readily available option\n", "for training this model to decent performance on this dataset with this codebase\n", "comes out around $10, using\n", "[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n", "See, for example,\n", "[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n", "and associated experiment.\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HufjdUZN0t4l", "scrolled": false }, "outputs": [], "source": [ "%%time\n", "# above %%magic times the cell, useful as a poor man's profiler\n", "\n", "%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n", " --gpus={gpus} --batch_size 16 --precision 16 \\\n", " --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2" ] }, { "cell_type": "markdown", "metadata": { "id": "L6fQ93ju3Iku" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "udb1Ekjx3L63" }, "source": [ "### 🌟 Try out gradient accumulation and other \"training tricks\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "kpqViB4p3Wfb" }, "source": [ "Larger batches are helpful not only for increasing parallelization\n", "and amortizing fixed costs\n", "but also for getting more reliable gradients.\n", "Larger batches give gradients with less noise\n", "and to a point, less gradient noise means faster convergence.\n", "\n", "But larger batches result in larger tensors,\n", "which take up more GPU memory,\n", "a resource that is tightly constrained\n", "and device-dependent.\n", "\n", "Does that mean we are limited in the quality of our gradients\n", "due to our machine size?\n", "\n", "Not entirely:\n", "look up the `--accumulate_grad_batches`\n", "argument to the `pl.Trainer`.\n", "You should be able to understand why\n", "it makes it possible to compute the same gradients\n", "you would find for a batch of size `k * N`\n", "on a machine that can only run batches up to size `N`.\n", "\n", "Accumulating gradients across batches is among the\n", "[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n", "Try some of them out!\n", "Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment." 
] }, { "cell_type": "markdown", "metadata": { "id": "b2vtkmX830y3" }, "source": [ "### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n", "\n", "While training this model to actually fit the whole dataset is infeasible\n", "as a short exercise on commodity hardware,\n", "it's practical to train this model to memorize a batch of 16 examples.\n", "\n", "Passing `--overfit_batches 1` flag limits the number of training batches to 1\n", "and turns off\n", "[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n", "so that in each epoch, the model just sees the same single batch of data over and over again.\n", "\n", "At first, try training the model to a loss of `2.5` --\n", "it should be doable in 100 epochs or less,\n", "which is just a few minutes on a commodity GPU.\n", "\n", "Once you've got that working,\n", "crank up the number of epochs by a factor of 10\n", "and confirm that the loss continues to go down.\n", "\n", "Some tips:\n", "\n", "- Use `--limit_test_batches 0` to turn off testing.\n", "We don't need it because we don't care about generalization\n", "and it's relatively slow because it runs the model autoregressively.\n", "\n", "- Use `--help` and look through the model class args\n", "to find the arguments used to reduce model size.\n", "\n", "- By default, there's lots of regularization to prevent overfitting.\n", "Look through the args for the model class and data class\n", "for regularization knobs to turn off or down." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab03_transformers.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab04/notebooks/lab04_experiments.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 04: Experiment Management" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How experiment management brings observability to ML model development\n", "- Which features of experiment management we use in developing the Text Recognizer\n", "- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 4\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This lab contains a large number of embedded iframes\n", "that benefit from having a wide window.\n", "The cell below makes the notebook as wide as your browser window\n", "if `full_width` is set to `True`.\n", "Full width is the default behavior in Colab,\n", "so this cell is intended to improve the viewing experience in other Jupyter environments." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", "    # add styling to the notebook's HTML directly\n", "    display(HTML(\"\"))\n", "    display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Why experiment management?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand why we need experiment management for ML development,\n", "let's start by running an experiment.\n", "\n", "We'll train a new model on a new dataset,\n", "using the training script `training/run_experiment.py`\n", "introduced in [Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a CNN encoder and Transformer decoder, as in\n", "[Lab 03](https://fsdl.me/lab03-colab),\n", "but with some changes so we can iterate faster.\n", "We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n", "[Lab 02b](https://fsdl.me/lab02b-colab),\n", "and we'll use a smaller CNN (`--model_class LineCNNTransformer`)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n", "from text_recognizer.data import IAMLines # processed version split into individual lines\n", "from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n", "\n", "\n", "print(IAM.__doc__)\n", "\n", "# uncomment a line below for details on either class\n", "# IAMLines?? \n", "# LineCNNTransformer??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below will train a model on 10% of the data for two epochs.\n", "\n", "It takes up to a few minutes to run on commodity hardware,\n", "including data download and preprocessing.\n", "As it's running, continue reading below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%%time\n", "import torch\n", "\n", "\n", "gpus = int(torch.cuda.is_available()) \n", "\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the model trains, we're calculating lots of metrics --\n", "loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n", "and reporting them to the terminal.\n", "\n", "This is achieved by the built-in `.log` method\n", "([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n", "of the `LightningModule`,\n", "and it is a very straightforward way to get basic information about your experiment as it's running\n", "without leaving the context where you're running it." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Learning to read\n", "[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n", "is something of a rite of passage for MLEs, but\n", "let's consider what we can't see here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We're missing all metric values except the most recent --\n", "we can see them as they stream in, but they're constantly overwritten.\n", "We also can't associate them with timestamps, steps, or epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We also don't see any system metrics.\n", "We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n", "without launching a separate process.\n", "And even if we do, those values will also not be saved and timestamped,\n", "so we can't correlate them with other things during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- As we continue to run experiments, changing code and opening new terminals,\n", "even the information we have or could figure out now will disappear.\n", "Say you spot a weird error message during training,\n", "but your session ends and the stdout is gone,\n", "so you don't know exactly what it was.\n", "Can you recreate the error?\n", "Which git branch and commit were you on?\n", "Did you have any uncommitted changes? Which arguments did you pass?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Also, model checkpoints containing the parameter values have been saved to disk.\n", "Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n", "As we run more and more experiments,\n", "we'll want to slice and dice them to see if,\n", "say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to save and log all of this information, and more, in order to make our model training\n", "[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n", "in short, so that we can understand, make decisions about, and debug our model training\n", "by looking at logs and source code, without having to recreate it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we had to write the logging code we need to save this information ourselves, that'd put us in for a world of hurt:\n", "1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n", "2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n", "3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n", "4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Local Experiment Tracking with Tensorboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n", "and it makes logging very easy." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n", "\n", "It can also use a logger to push information elsewhere.\n", "\n", "By default, we use\n", "[TensorBoard](https://www.tensorflow.org/tensorboard)\n", "via the Lightning `TensorBoardLogger`,\n", "which has been saving results to the local disk.\n", "\n", "Let's find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest experiment's directory\n", "# by hand, you can just copy and paste it from the terminal\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n", "filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n", "latest_log" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!ls -lh {latest_log}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, we need to launch a TensorBoard server --\n", "much like we need to launch a Jupyter server to use Jupyter notebooks.\n", "\n", "The cells below load an extension that lets you use TensorBoard inside of a notebook\n", "the same way you'd use it from the command line, and then launch it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext tensorboard" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n", "\n", "port = 11717 # pick an open port on your machine\n", "host = \"0.0.0.0\" # allow connections from the internet\n", " # watch out! make sure you turn TensorBoard off\n", "\n", "%tensorboard --logdir {latest_log} --port {port} --host {host}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see some charts of metrics over time along with some charting controls.\n", "\n", "You can click around in this interface and explore it if you'd like,\n", "but in the next section, we'll see that there are better tools for experiment management." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've run many experiments on this machine,\n", "you can see all of their results by pointing TensorBoard\n", "at the whole `lightning_logs` directory,\n", "rather than just one experiment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For large numbers of experiments, the management experience is not great --\n", "it's for example hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n", "\n", "It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n", "which are important workflows as applications mature and teams grow." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n", "\n", "If you run into any issues with the above cells failing to launch,\n", "especially across iterations of this lab, run this cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorboard.manager\n", "\n", "# get the process IDs for all tensorboard instances\n", "pids = [tb.pid for tb in tensorboard.manager.get_all()]\n", "\n", "done_with_tensorboard = False\n", "\n", "if done_with_tensorboard:\n", "    # kill processes\n", "    for pid in pids:\n", "        !kill {pid} 2> /dev/null\n", "    \n", "    # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n", "    !rm -rf {tensorboard.manager._get_info_dir()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment Management with Weights & Biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How do we manage experiments when we hit the limits of local TensorBoard?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is powerful and flexible and very scalable,\n", "but running it requires engineering effort and babysitting --\n", "you're running a database, writing data to it,\n", "and layering a web application over it.\n", "\n", "This is a fairly common workflow for web developers,\n", "but not so much for ML engineers.\n", "\n", "You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n", "and it's as simple as running the command `tensorboard dev upload`\n", "pointed at your logging directory.\n", "\n", "But there are strict limits to this free service:\n", "1GB of tensor data and 1GB of binary data.\n", "A single Text Recognizer model checkpoint is ~100MB,\n", "and that's not particularly large for a useful model.\n", "\n", "Furthermore, all data is public,\n", "so if you upload the inputs and outputs of your model,\n", "anyone who finds the link can see them.\n", "\n", "Overall, tensorboard.dev works very well for certain academic and open projects\n", "but not for industrial ML." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To avoid that narrow permissions and limits issue,\n", "you could use [git LFS](https://git-lfs.github.com/)\n", "to track the binary data and tensor data,\n", "which is more likely to be sensitive than metrics.\n", "\n", "The Hugging Face ecosystem uses TensorBoard and git LFS.\n", "\n", "It includes the Hugging Face Hub, a git server much like GitHub,\n", "but designed first and foremost for collaboration on models and datasets,\n", "rather than collaboration on code.\n", "For example, the Hugging Face Hub\n", "[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n", "and officially has\n", "[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n", "avoiding the\n", "[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n", "that make using git LFS with GitHub expensive.\n", "\n", "However, we prefer to avoid mixing software version control and experiment management.\n", "\n", "First, using the Hub requires maintaining an additional git remote,\n", "which is a hard ask for many engineering teams.\n", "\n", "Secondly, git-style versioning is an awkward fit for logging --\n", "is it really sensible to create a new commit for each logging event while you're watching live?\n", "\n", "Instead, we prefer to use systems that solve experiment management with _databases_." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n", "The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n", "tool is [MLflow](https://github.com/mlflow/mlflow/)\n", "and there are a number of\n", "[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n", "\n", "These tools generally avoid any need to worry about hosting\n", "(unless data governance rules require a self-hosted version).\n", "\n", "For a sampling of publicly-posted opinions on experiment management tools,\n", "see these discussions from Reddit:\n", "\n", "- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n", "- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n", "\n", "Among these tools, the FSDL recommendation is\n", "[Weights & Biases](https://wandb.ai),\n", "which we believe offers\n", "- the best user experience, both in the Python SDKs and in the graphical interface\n", "- the best integrations with other tools,\n", "including\n", "[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n", "[Keras](https://docs.wandb.ai/guides/integrations/keras),\n", "[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n", "and even\n", "[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n", "and\n", "- the best tools for collaboration.\n", "\n", "Below, we'll take care to point out which logging and management features\n", "are available via generic interfaces in Lightning and which are W&B-specific." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "\n", "print(wandb.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding it to our experiment running code is extremely easy,\n", "relative to the features we get, which is\n", "one of the main selling points of W&B.\n", "\n", "We get most of our new experiment management features just by changing a single variable, `logger`, from\n", "`TensorBoardLogger` to `WandbLogger`\n", "and adding two lines of code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see what each of these lines does for us below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this logger is built into and maintained by PyTorch Lightning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning.loggers import WandbLogger\n", "\n", "\n", "WandbLogger??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to complete the rest of this notebook,\n", "you'll need a Weights & Biases account.\n", "\n", "As with GitHub, the free tier for personal, academic, and open source work\n", "is very generous.\n", "\n", "The Text Recognizer project will fit comfortably within the free tier.\n", "\n", "Run the cell below and follow the prompts to log in or create an account, or go\n", "[here](https://wandb.ai/signup)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb login" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the cell below to launch an experiment tracked with Weights & Biases.\n", "\n", "The experiment can take between 3 and 10 minutes to run.\n", "In that time, continue reading below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n", " --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1\n", " \n", "last_expt = wandb.run\n", "\n", "wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see some new things in our output.\n", "\n", "For example, there's a note from `wandb` that the data is saved locally\n", "and also synced to their servers.\n", "\n", "There's a link to a webpage for viewing the logged data and a name for our experiment --\n", "something like `dandy-sunset-1`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local logging and cloud syncing happens with minimal impact on performance,\n", "because `wandb` launches a separate process to listen for events and upload them.\n", "\n", "That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Runs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, head to the link in the notebook output\n", "that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n", "\n", "There's no need to wait for training to finish.\n", "\n", "The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(last_expt.url)\n", "IFrame(last_expt.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have landed on the run page\n", "([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n", "which collects up all of the information for a single experiment into a collection of tabs.\n", "\n", "We'll work through these tabs from top to bottom.\n", "\n", "Each header is also a link to the documentation for a tab." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n", "This tab has an icon that looks like `(i)` or 🛈.\n", "\n", "The top section of this tab has high-level information about our run:\n", "- Timing information, like start time and duration\n", "- System hardware, hostname, and basic environment info\n", "- Git repository link and state\n", "\n", "This information is collected and logged automatically.\n", "\n", "The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n", "and summary metrics.\n", "\n", "Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n", "\n", "This tab has a line plot icon, something like 📈.\n", "\n", "It's also the default page you land on when looking at a W&B run.\n", "\n", "Charts are generated for everything we `.log` from PyTorch Lightning. 
The charts here are interactive and editable, and changes persist.\n", "\n", "Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n", "\n", "We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n", "This tab has a computer chip icon.\n", "\n", "It contains\n", "- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n", "- CPU metrics: memory usage, utilization, thread counts\n", "- Disk and network I/O levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n", "This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n", "\n", "The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n", "This tab has an icon that looks like a stylized command prompt, `>_`.\n", "\n", "It contains information that was printed to stdout.\n", "\n", "This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n", "\n", "Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelSummary\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n", "\n", "For more on Lightning `Callback`s, see\n", "[Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n", "This tab has a stylized document icon, something like 📄.\n", "\n", "You can use this tab to view any files saved with the `wandb.save`.\n", "\n", "For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n", "which we'll discuss shortly.\n", "\n", "But a few pieces of information automatically collected by W&B end up in this tab.\n", "\n", "Some highlights:\n", " - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n", " - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n", "This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n", "\n", "This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n", "\n", "We store two kinds of binary files\n", " - `run_table`s of model inputs and outputs\n", " - `model` checkpoints\n", "\n", "We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive Tables of Logged Media" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returning to the Charts tab,\n", "notice that we have model inputs and outputs logged in structured tables\n", "under the train, validation, and test sections.\n", "\n", "These tables are interactive as well\n", "([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n", "They support basic exploratory data analysis and are compatible with W&B's collaboration features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to charts in our run page, these tables also have their own pages inside the W&B web app." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n", "table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n", "\n", "print(table_data_url)\n", "IFrame(src=table_data_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting this to work requires more effort and more W&B-specific code\n", "than the other features we've seen so far.\n", "\n", "We'll briefly explain the implementation here, for those who are interested.\n", "\n", "We use a custom Lightning `Callback`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n", "\n", "\n", "ImageToTextTableLogger??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n", "\n", "The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n", "\n", "This behavior is sensible for metrics, which are low overhead, but not so much for media,\n", "where we'd rather subsample and avoid holding on to too much information.\n", "\n", "So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n", "\n", "The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.lit_models.base import BaseImageToTextLitModel\n", "\n", "BaseImageToTextLitModel.add_on_logged_batches??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Projects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Everything we've seen so far has been related to a single run or experiment.\n", "\n", "Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n", "\n", "We organize our runs into \"projects\" and view them on the W&B \"project page\" \n", "([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n", "\n", "By default in the Lightning integration, the project name is determined based on directory information.\n", "This default can be overridden in the code when creating a `WandbLogger`,\n", "but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what the project page looks like for a longer-running project with lots of experiments.\n", "\n", "The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n", "\n", "print(project_url)\n", "IFrame(src=project_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n", "\n", "We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Artifacts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond logging metrics and metadata from runs,\n", "we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below pulls up all of the artifacts associated with the experiment we just ran." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Click on one of the `model` checkpoints -- the specific version doesn't matter.\n", "\n", "There are a number of tabs here.\n", "\n", "The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n", "\n", "The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n", "which are added by default by the `WandbLogger`.\n", "\n", "The \"Files\" tab contains the actual file contents of the artifact.\n", "\n", "On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n", "including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n", "\n", "You can click on these to explore the different versions and even directly compare them.\n", "\n", "If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Artifact storage is part of the W&B free tier.\n", "\n", "The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n", "\n", "The former is sufficient to store ~700 model checkpoints for the Text Recognizer." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can track your data storage and compare it to your limits at this URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n", "\n", "print(storage_tracker_url)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Programmatic Access" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also programmatically access our data and metadata via the `wandb` API\n", "([docs](https://docs.wandb.ai/guides/track/public-api-guide)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wb_api = wandb.Api()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run = wb_api.run(\"/\".join( # fetch a run given\n", "    [last_expt.entity, # the user or org it was logged to\n", "     last_expt.project, # the \"project\", usually one of several per repo/application\n", "     last_expt.id] # and a unique ID\n", "))\n", "\n", "hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n", "\n", "hist.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hist.groupby(\"epoch\")[\"train/loss\"].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this includes the artifacts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# which artifacts were created and logged?\n", "artifacts = run.logged_artifacts()\n", "\n", "for artifact in artifacts:\n", "    print(f\"artifact of type {artifact.type}: {artifact.name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thanks to our 
`ImageToTextTableLogger`,\n", "we can easily recreate training or validation data that came out of our `DataLoader`s,\n", "which is normally ephemeral:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n", "artifact_dir = Path(artifact.download(root=\"training/logs\"))\n", "image_dir = artifact_dir / \"media\" / \"images\"\n", "\n", "images = [path for path in image_dir.iterdir()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "from IPython.display import Image\n", "\n", "Image(str(random.choice(images)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Advanced W&B API Usage: MLOps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the strengths of a well-instrumented experiment tracking system is that it allows\n", "automatic relation of information:\n", "what were the inputs when this model's gradient spiked?\n", "Which models have been trained on this dataset,\n", "and what was their performance?\n", "\n", "Having access and automation around this information is necessary for \"MLOps\",\n", "which applies contemporary DevOps principles to ML projects." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cells below pull down the training data\n", "for the model currently running the FSDL Text Recognizer app.\n", "\n", "This is just intended as a demonstration of what's possible,\n", "so don't worry about understanding every piece of this,\n", "and feel free to skip past it.\n", "\n", "MLOps is still a nascent field, and these tools and workflows are likely to change.\n", "\n", "For example, just before the course launched, W&B released a\n", "[Model Registry layer](https://docs.wandb.ai/guides/models)\n", "on top of artifact logging that aims to improve the developer experience for these workflows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start from the same project we looked at in the project view:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n", "\n", "text_recognizer_project" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we search it for the text recognizer model currently being used in production:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# collect all versions of the text-recognizer ever put into production by...\n", "\n", "for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n", "    if art_type.name == \"prod-ready\": # for the prod-ready type\n", "        # and grabbing the text-recognizer\n", "        production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n", "\n", "# and then get the one that's currently being tested in CI by...\n", "for text_recognizer in production_text_recognizers:\n", "    if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n", "        in_prod_text_recognizer = text_recognizer\n", "\n", "# view its metadata at the url or 
in the notebook\n", "in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n", "\n", "print(in_prod_text_recognizer_url)\n", "IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From its metadata, we can get information about how it was \"staged\" to be put into production,\n", "and in particular which model checkpoint was used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "staging_run = in_prod_text_recognizer.logged_by()\n", "\n", "training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n", "training_ckpt.name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That checkpoint was logged by a training experiment, which is available as metadata.\n", "\n", "We can look at the training run for that model, either here in the notebook or at its URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "training_run = training_ckpt.logged_by()\n", "print(training_run.url)\n", "IFrame(src=training_run.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And from there, we can access logs and metadata about training,\n", "confident that we are working with the model that is actually in production.\n", "\n", "For example, we can pull down the data we logged and analyze it locally." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_results = training_run.history(samples=10000)\n", "training_results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n", "training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 10\n", "training_results[\"validation/loss\"].dropna().iloc[idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The charts and webpages in Weights & Biases\n", "are substantially more useful than ephemeral stdouts or raw logs on disk.\n", "\n", "If you're spun up on the project,\n", "they accelerate debugging, exploration, and discovery.\n", "\n", "If not, they're not so much useful as they are overwhelming.\n", "\n", "We need to synthesize the raw logged data into information.\n", "This helps us communicate our work with other stakeholders,\n", "preserve knowledge and prevent repetition of work,\n", "and surface insights faster.\n", "\n", "These workflows are supported by the W&B Reports feature\n", "([docs here](https://docs.wandb.ai/guides/reports)),\n", "which mixes W&B charts and tables with explanatory markdown text and embeds.\n", "\n", "Below are some common report patterns and\n", "use cases and examples of each." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the examples are from the FSDL Text Recognizer project.\n", "You can find more of them\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n", "where we've organized them into a report!" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dashboard Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dashboards are a structured subset of the output from one or more experiments,\n", "designed for quickly surfacing issues or insights,\n", "like an accuracy or performance regression\n", "or a change in the data distribution.\n", "\n", "Use cases:\n", "- show the basic state of an ongoing experiment\n", "- compare one experiment to another\n", "- select the most important charts so you can spin back up into context on a project more quickly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n", "\n", "IFrame(src=dashboard_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull Request Documentation Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In most software codebases,\n", "pull requests are a key focal point\n", "for units of work that combine\n", "short-term communication and long-term information tracking.\n", "\n", "In ML codebases, it's more difficult to bring\n", "sufficient information together to make PRs as useful.\n", "At FSDL, we like to add documentary\n", "reports with one or a small number of charts\n", "that connect logged information in the experiment management system\n", "to state in the version control software.\n", "\n", "Use cases:\n", "- communication of results within a team, e.g. 
code review\n", "- record-keeping that links pull request pages to raw logged info and makes it discoverable\n", "- improving confidence in PR correctness" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n", "\n", "IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Blog Post Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With sufficient effort, the logged data in the experiment management system\n", "can be made clear enough to be consumed,\n", "sufficiently contextualized to be useful outside the team, and\n", "even beautiful.\n", "\n", "The result is a report that's closer to a blog post than a dashboard or internal document.\n", "\n", "Use cases:\n", "- communication between teams or vertically in large organizations\n", "- external technical communication for branding and recruiting\n", "- attracting users or contributors\n", "\n", "Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma)\n", "and others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n", "\n", "IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many of our choices, like the depth of our network, the nonlinearities of our layers,\n", "and the learning rate and other parameters of our optimizer, cannot be\n", "([easily](https://arxiv.org/abs/1606.04474))\n", "chosen by descent of 
the gradient of a loss function.\n", "\n", "But these parameters that impact the values of the parameters\n", "we directly optimize with gradients, or _hyperparameters_,\n", "can still be optimized,\n", "essentially by trying options and selecting the values that worked best.\n", "\n", "In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n", "\n", "Expending more compute can squeeze small amounts of additional validation or test performance\n", "that makes for impressive results on leaderboards but typically doesn't translate\n", "into better user experience.\n", "\n", "In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n", "built into your other tooling.\n", "\n", "Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n", "([docs](https://docs.wandb.ai/guides/sweeps)).\n", "\n", "It also supports a number of more advanced tools, like\n", "[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n", "for early termination of poorly-performing runs.\n", "\n", "We can use the same training script and we don't need to run an optimization server.\n", "\n", "We just need to write a configuration yaml file\n", "([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n", "like the one below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile training/simple-overfit-sweep.yaml\n", "# first we specify what we're sweeping\n", "# we specify a program to run\n", "program: training/run_experiment.py\n", "# we optionally specify how to run it, including setting default arguments\n", "command:\n", "  - ${env}\n", "  - ${interpreter}\n", "  - ${program}\n", "  - \"--wandb\"\n", "  - \"--overfit_batches\"\n", "  - \"1\"\n", "  - \"--log_every_n_steps\"\n", "  - \"25\"\n", "  - \"--max_epochs\"\n", "  - \"100\"\n", "  - \"--limit_test_batches\"\n", "  - \"0\"\n", "  - ${args} # these arguments come from the sweep parameters below\n", "\n", "# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n", "method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n", "metric:\n", "  name: train/loss\n", "  goal: minimize\n", "parameters:\n", "  # LineCNN hyperparameters\n", "  window_width:\n", "    values: [8, 16, 32, 64]\n", "  window_stride:\n", "    values: [4, 8, 16, 32]\n", "  # Transformer hyperparameters\n", "  tf_layers:\n", "    values: [1, 2, 4, 8]\n", "  # we can also fix some values, just like we set default arguments\n", "  gpus:\n", "    value: 1\n", "  model_class:\n", "    value: LineCNNTransformer\n", "  data_class:\n", "    value: IAMLines\n", "  loss:\n", "    value: transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the config we launch a \"controller\":\n", "a lightweight process that just decides what hyperparameters to try next\n", "and coordinates the heavier-weight training.\n", "\n", "This lives on the W&B servers, so there are no headaches about opening ports for communication,\n", "cleaning up when it's done, etc." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n", "simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we can launch an \"agent\" to follow the orders of the controller:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "\n", "# interrupt twice to terminate this cell if it's running too long,\n", "# it can be over 15 minutes with some hyperparameters\n", "\n", "!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n", "\n", "If not provided, the agent will run forever for random or Bayesian sweeps\n", "or until the sweep is terminated, which can be done from the W&B interface." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agents make for a slick workflow for distributing sweeps across GPUs.\n", "\n", "We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n", "which controls which GPUs are accessible by a process, to launch\n", "parallel agents on separate GPUs on the same machine." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n", "# open another terminal\n", "CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n", "# and so on\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include optional exercises with the labs for learners who want to dive deeper on specific topics." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟Contribute to a hyperparameter search." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n", "\n", "There are ~10,000,000 potential hyperparameter combinations,\n", "and each takes 30 minutes to test,\n", "so checking each possibility will take over 500 years of compute time.\n", "Best get cracking then!\n", "\n", "Run the cell below to pull up a dashboard and print the URL where you can check on the current status." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_entity = \"fullstackdeeplearning\"\n", "sweep_project = \"fsdl-line-recognizer-2022\"\n", "sweep_id = \"e0eo43eu\"\n", "sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n", "\n", "print(sweep_url)\n", "IFrame(src=sweep_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve information about the sweep from the API,\n", "including the hyperparameters being swept over." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparams = sweep_info.config[\"parameters\"]\n", "hyperparams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you'd like to contribute to this sweep,\n", "run the cell below after changing the count to a number greater than 0.\n", "\n", "Each iteration runs for 30 minutes if it does not crash,\n", "e.g. due to out-of-memory errors." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "count = 0 # off by default, increase it to join in!\n", "\n", "if count:\n", "    !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Write some manual logging in `wandb`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the FSDL Text Recognizer codebase,\n", "we almost exclusively log to W&B through Lightning,\n", "rather than through the `wandb` Python SDK.\n", "\n", "If you're interested in learning how to use W&B directly, e.g. with another training framework,\n", "try out this quick exercise that introduces the key players in the SDK." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n", "\n", "It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n", "\n", "Add W&B metric and artifact logging to this cell:\n", "- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n", "- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import random\n", "\n", "import wandb\n", "\n", "\n", "os.makedirs(\"logs\", exist_ok=True)\n", "\n", "project = \"trying-wandb\"\n", "config = {\"steps\": 50}\n", "\n", "\n", "with wandb.init(project=project, config=config) as run:\n", "    steps = wandb.config[\"steps\"]\n", "\n", "    for ii in range(steps):\n", "        loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n", "\n", "    with open(\"logs/hello.txt\", \"w\") as f:\n", "        f.write(\"hello from wandb, my dudes!\")\n", "\n", "    run_id = run.id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hello_run = wb_api.run(f\"{project}/{run_id}\")\n", "\n", "# check for logged loss data\n", "if \"loss\" not in hello_run.history().keys():\n", "    print(\"loss not logged 🥲\")\n", "else:\n", "    print(\"loss logged successfully 🥞\")\n", "    if len(hello_run.history()[\"loss\"]) != steps:\n", "        print(\"loss not logged on all steps 🥲\")\n", "    else:\n", "        print(\"loss logged on all steps 🥞\")\n", "\n", "artifacts = hello_run.logged_artifacts()\n", "\n", "# check for artifact with the right name\n", "if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n", "    print(\"hello artifact not logged 🥲\")\n", "else:\n", "    print(\"hello artifact logged successfully 🥞\")\n", "    # check for the file inside the artifacts\n", "    if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n", "        print(\"could not find hello.txt 🥲\")\n", "    else:\n", "        print(\"hello.txt logged successfully 🥞\")\n", "\n", "\n", "hello_run" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you observe interesting phenomena during training,\n", "from promising hyperparameter combos to software bugs to strange model behavior,\n", "turn the charts into a W&B report and share it with the FSDL community or\n", "[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n", "with a link to them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# check the sweep_info.config above to see the model and data hyperparameters\n", "# read through the --help output for all potential arguments\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n", " --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n", " --help # remove this line to run an experiment instead of printing help\n", " \n", "last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟🌟🌟 Add logging of tensor statistics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to logging model inputs and outputs as human-interpretable media,\n", "it's also frequently useful to see information about their numerical values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're interested in learning more about metric calculation and logging with Lightning,\n", "use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n", "to add tensor statistic logging to the `LineCNNTransformer`.\n", "\n", "`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n", "\n", "All three are useful, but start by adding just one." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n", "- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n", " - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n", "- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n", " - Base your code on the calculation and logging of the `val_cer` metric.\n", " - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ", "collapsed_sections": [], "name": "lab04_experiments.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab04/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab04/text_recognizer/callbacks/__init__.py ================================================ from .model import ModelSizeLogger from .optim import LearningRateMonitor from . 
import imtotext from .imtotext import ImageToTextTableLogger as ImageToTextLogger ================================================ FILE: lab04/text_recognizer/callbacks/imtotext.py ================================================ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only try: import wandb has_wandb = True except ImportError: has_wandb = False from .util import check_and_warn class ImageToTextTableLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.on_train: if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "validation/predictions") def _log_image_text_table(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = xs[:mx], gt_strs[:mx], pred_strs[:mx] xs = [wandb.Image(x) for x in xs] rows = zip(*[xs, gt_strs, pred_strs]) columns = ["input_image", "ground_truth_string", "predicted_string"] trainer.logger.log_table(key=key, columns=columns, data=list(rows)) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) class ImageToTextCaptionLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & 
Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "validation/predictions") @rank_zero_only def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "test/predictions") def _log_image_text_caption(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], pred_strs[:mx] trainer.logger.log_image(key, xs, caption=pred_strs) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) ================================================ FILE: lab04/text_recognizer/callbacks/model.py ================================================ import os from pathlib import Path import tempfile import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_only import torch from .util import check_and_warn, logging try: import torchviz has_torchviz = True except ImportError: has_torchviz = False class ModelSizeLogger(pl.Callback): """Logs information about model size (in parameters and on disk).""" def 
__init__(self, print_size=True): super().__init__() self.print_size = print_size @rank_zero_only def on_fit_start(self, trainer, module): self._run(trainer, module) def _run(self, trainer, module): metrics = {} metrics["mb_disk"] = self.get_model_disksize(module) metrics["nparams"] = count_params(module) if self.print_size: print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB") metrics = {f"size/{key}": value for key, value in metrics.items()} trainer.logger.log_metrics(metrics, step=-1) @staticmethod def get_model_disksize(module): """Determine the model's size on disk by saving it to disk.""" with tempfile.NamedTemporaryFile() as f: torch.save(module.state_dict(), f) size_mb = os.path.getsize(f.name) / 1e6 return size_mb class GraphLogger(pl.Callback): """Logs a compute graph as an image.""" def __init__(self, output_key="logits"): super().__init__() self.graph_logged = False self.output_key = output_key if not has_torchviz: raise ImportError("GraphLogCallback requires torchviz." "") @rank_zero_only def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx): if not self.graph_logged: try: outputs = outputs[0][0]["extra"] self.log_graph(trainer, module, outputs[self.output_key]) except KeyError: logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}") self.graph_logged = True @staticmethod def log_graph(trainer, module, outputs): if check_and_warn(trainer.logger, "log_image", "graph"): return params_dict = dict(list(module.named_parameters())) graph = torchviz.make_dot(outputs, params=params_dict) graph.format = "png" fname = Path(trainer.logger.experiment.dir) / "graph" graph.render(fname) fname = str(fname.with_suffix("." 
+ graph.format)) trainer.logger.log_image(key="graph", images=[fname]) def count_params(module): """Counts the number of parameters in a Torch Module.""" return sum(p.numel() for p in module.parameters()) ================================================ FILE: lab04/text_recognizer/callbacks/optim.py ================================================ import pytorch_lightning as pl KEY = "optimizer" class LearningRateMonitor(pl.callbacks.LearningRateMonitor): """Extends Lightning's LearningRateMonitor with a prefix. Logs the learning rate during training. See the docs for pl.callbacks.LearningRateMonitor for details. """ def _add_prefix(self, *args, **kwargs) -> str: return f"{KEY}/" + super()._add_prefix(*args, **kwargs) ================================================ FILE: lab04/text_recognizer/callbacks/util.py ================================================ import logging logging.basicConfig(level=logging.WARNING) def check_and_warn(logger, attribute, feature): if not hasattr(logger, attribute): warn_no_attribute(feature, attribute) return True def warn_no_attribute(blocked_feature, missing_attribute): logging.warning(f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}.") ================================================ FILE: lab04/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. 
""" from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines from .iam_paragraphs import IAMParagraphs from .iam_lines import IAMLines ================================================ FILE: lab04/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else 
DEFAULT_NUM_WORKERS class BaseDataModule(pl.LightningDataModule): """Base for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`). """ def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting. 
Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab04/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. 
"The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = 
toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") 
shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab04/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab04/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims = 
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines) ================================================ FILE: lab04/text_recognizer/data/iam.py ================================================ """Class for loading the IAM handwritten text dataset, which encompasses both paragraphs and lines, plus utilities.""" from pathlib import Path from typing import Any, cast, Dict, List, Optional import zipfile from boltons.cacheutils import cachedproperty from defusedxml import ElementTree from PIL import Image, ImageOps import toml from text_recognizer import util from text_recognizer.data.base_data_module import _download_raw_dataset, load_and_print_info import text_recognizer.metadata.iam as metadata from text_recognizer.metadata.iam_paragraphs import NEW_LINE_TOKEN METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME EXTRACTED_DATASET_DIRNAME = metadata.EXTRACTED_DATASET_DIRNAME class IAM: """A dataset of images of handwritten text written on a form underneath a typewritten prompt. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database Images are identified by their "form ID". These IDs are used to separate train, validation and test splits, as keys for dictionaries returning label and image crop region data, and more. The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (LWITLRT): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. 
The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. """ def __init__(self): self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self): if self.xml_filenames: return filename = _download_raw_dataset(self.metadata, DL_DATA_DIRNAME) # type: ignore _extract_raw_dataset(filename, DL_DATA_DIRNAME) def load_image(self, id: str) -> Image.Image: """Load and return an image of an entire IAM form. The image is grayscale with white text on black background. This image will have the printed prompt text at the top, above the handwritten text. Images of individual words or lines and of whole paragraphs can be cropped out using the relevant crop region data. """ image = util.read_image_pil(self.form_filenames_by_id[id], grayscale=True) image = ImageOps.invert(image) return image def __repr__(self): """Print info about the dataset.""" info = ["IAM Dataset"] info.append(f"Total Images: {len(self.xml_filenames)}") info.append(f"Total Test Images: {len(self.test_ids)}") info.append(f"Total Paragraphs: {len(self.paragraph_string_by_id)}") num_lines = sum(len(line_regions) for line_regions in self.line_regions_by_id.values()) info.append(f"Total Lines: {num_lines}") return "\n\t".join(info) @cachedproperty def all_ids(self): """A list of all form IDs.""" return sorted([f.stem for f in self.xml_filenames]) @cachedproperty def ids_by_split(self): return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids} @cachedproperty def split_by_id(self): """A dictionary mapping form IDs to their split according to IAM Lines LWITLRT.""" split_by_id = {id_: "train" for id_ in self.train_ids} split_by_id.update({id_: "val" for id_ in self.validation_ids}) split_by_id.update({id_: "test" for id_ in self.test_ids}) return split_by_id @cachedproperty def train_ids(self): """A list of form IDs which are in the IAM Lines LWITLRT training set.""" return list(set(self.all_ids) - (set(self.test_ids) | 
set(self.validation_ids))) @cachedproperty def test_ids(self): """A list of form IDs from the IAM Lines LWITLRT test set.""" return _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/testset.txt") @property def xml_filenames(self) -> List[Path]: """A list of the filenames of all .xml files, which contain label information.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @cachedproperty def validation_ids(self): """A list of form IDs from IAM Lines LWITLRT validation sets 1 and 2.""" val_ids = _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset1.txt") val_ids.extend(_get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset2.txt")) return val_ids @property def form_filenames(self) -> List[Path]: """A list of the filenames of all .jpg files, which contain images of IAM forms.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def xml_filenames_by_id(self): """A dictionary mapping form IDs to their XML label information files.""" return {filename.stem: filename for filename in self.xml_filenames} @property def form_filenames_by_id(self): """A dictionary mapping form IDs to their JPEG images.""" return {filename.stem: filename for filename in self.form_filenames} @cachedproperty def line_strings_by_id(self): """A dict mapping an IAM form id to its list of line texts.""" return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def line_regions_by_id(self): """A dict mapping an IAM form id to its list of line image crop regions.""" return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def paragraph_string_by_id(self): """A dict mapping an IAM form id to its paragraph text.""" return {id: NEW_LINE_TOKEN.join(line_strings) for id, line_strings in self.line_strings_by_id.items()} @cachedproperty def paragraph_region_by_id(self): """A dict 
mapping an IAM form id to its paragraph image crop region.""" return { id: { "x1": min(region["x1"] for region in line_regions), "y1": min(region["y1"] for region in line_regions), "x2": max(region["x2"] for region in line_regions), "y2": max(region["y2"] for region in line_regions), } for id, line_regions in self.line_regions_by_id.items() } def _extract_raw_dataset(filename: Path, dirname: Path) -> None: print("Extracting IAM data") with util.temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zip_file: zip_file.extractall() def _get_ids_from_lwitlrt_split_file(filename: str) -> List[str]: """Get the ids from Large Writer Independent Text Line Recognition Task (LWITLRT) data split file.""" with open(filename, "r") as f: line_ids_str = f.read() line_ids = line_ids_str.split("\n") page_ids = list({"-".join(line_id.split("-")[:2]) for line_id in line_ids if line_id}) return page_ids def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. 
Note that we replace &quot; with ".""" xml_line_elements = _get_line_elements_from_xml_file(filename) return [_get_text_from_xml_element(el) for el in xml_line_elements] def _get_text_from_xml_element(xml_element: Any) -> str: """Extract text from any XML element.""" return xml_element.attrib["text"].replace("&quot;", '"') def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get the line region dict for each line.""" xml_line_elements = _get_line_elements_from_xml_file(filename) line_regions = [ cast(Dict[str, int], _get_region_from_xml_element(xml_elem=el, xml_path="word/cmp")) for el in xml_line_elements ] assert all(region is not None for region in line_regions), "Line regions cannot be None" # next_line_region["y1"] - prev_line_region["y2"] can be negative due to overlapping characters line_gaps_y = [ max(next_line_region["y1"] - prev_line_region["y2"], 0) for next_line_region, prev_line_region in zip(line_regions[1:], line_regions[:-1]) ] post_line_gaps_y = line_gaps_y + [2 * metadata.LINE_REGION_PADDING] pre_line_gaps_y = [2 * metadata.LINE_REGION_PADDING] + line_gaps_y return [ { "x1": region["x1"] - metadata.LINE_REGION_PADDING, "x2": region["x2"] + metadata.LINE_REGION_PADDING, "y1": region["y1"] - min(metadata.LINE_REGION_PADDING, pre_line_gaps_y[i] // 2), "y2": region["y2"] + min(metadata.LINE_REGION_PADDING, post_line_gaps_y[i] // 2), } for i, region in enumerate(line_regions) ] def _get_line_elements_from_xml_file(filename: str) -> List[Any]: """Get all line xml elements from xml file.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec return xml_root_element.findall("handwritten-part/line") def _get_region_from_xml_element(xml_elem: Any, xml_path: str) -> Optional[Dict[str, int]]: """ Get region from input xml element. The region is downsampled because the stored images are also downsampled.
Parameters ---------- xml_elem xml element can be a line or word element with x, y, width, and height attributes xml_path should be "word/cmp" if xml_elem is a line element, else "cmp" """ unit_elements = xml_elem.findall(xml_path) if not unit_elements: return None return { "x1": min(int(el.attrib["x"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y1": min(int(el.attrib["y"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "x2": max(int(el.attrib["x"]) + int(el.attrib["width"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y2": max(int(el.attrib["y"]) + int(el.attrib["height"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, } if __name__ == "__main__": load_and_print_info(IAM) ================================================ FILE: lab04/text_recognizer/data/iam_lines.py ================================================ """A dataset of lines of handwritten text derived from the IAM dataset.""" import argparse import json from pathlib import Path from typing import Sequence import numpy as np from PIL import Image, ImageFile from text_recognizer import util from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.line import IAMLineStem ImageFile.LOAD_TRUNCATED_IMAGES = True PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR class IAMLines(BaseDataModule): """Lines of text pulled from the IAM Handwriting database.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true") == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() 
self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = IAMLineStem() self.trainval_transform = IAMLineStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if PROCESSED_DATA_DIRNAME.exists(): return print("Cropping IAM line regions...") iam = IAM() iam.prepare_data() crops_train, labels_train = generate_line_crops_and_labels(iam, "train") crops_val, labels_val = generate_line_crops_and_labels(iam, "val") crops_test, labels_test = generate_line_crops_and_labels(iam, "test") shapes = np.array([crop.size for crop in crops_train + crops_val + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] print("Saving images, labels, and statistics...") save_images_and_labels(crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_val, labels_val, "val", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt", "w") as file: file.write(str(aspect_ratios.max())) def setup(self, stage: str = None) -> None: with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt") as file: max_aspect_ratio = float(file.read()) image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) assert image_width <= metadata.IMAGE_WIDTH if stage == "fit" or stage is None: x_train, labels_train = load_processed_crops_and_labels("train", PROCESSED_DATA_DIRNAME) y_train = convert_strings_to_labels(labels_train, self.inverse_mapping, length=self.output_dims[0]) self.data_train = BaseDataset(x_train, y_train, transform=self.trainval_transform) x_val, labels_val = load_processed_crops_and_labels("val", PROCESSED_DATA_DIRNAME) y_val = convert_strings_to_labels(labels_val, self.inverse_mapping, length=self.output_dims[0]) self.data_val = 
BaseDataset(x_val, y_val, transform=self.trainval_transform) # quick check: do we have the right sequence lengths? assert self.output_dims[0] >= max([len(_) for _ in labels_train]) + 2 # Add 2 for start/end tokens. assert self.output_dims[0] >= max([len(_) for _ in labels_val]) + 2 # Add 2 for start/end tokens. if stage == "test" or stage is None: x_test, labels_test = load_processed_crops_and_labels("test", PROCESSED_DATA_DIRNAME) y_test = convert_strings_to_labels(labels_test, self.inverse_mapping, length=self.output_dims[0]) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) assert self.output_dims[0] >= max([len(_) for _ in labels_test]) + 2 def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Lines Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def generate_line_crops_and_labels(iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR): """Create both cropped lines and associated labels from IAM, with resizing by default""" crops, labels = [], [] for iam_id in iam.ids_by_split[split]: labels += iam.line_strings_by_id[iam_id] image = iam.load_image(iam_id) for line in iam.line_regions_by_id[iam_id]: coords = [line[point] for point in ["x1", "y1", "x2", "y2"]] crop = image.crop(coords) crop = resize_image(crop, scale_factor=scale_factor) 
crops.append(crop) assert len(crops) == len(labels) return crops, labels def save_images_and_labels(crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path): (data_dirname / split).mkdir(parents=True, exist_ok=True) with open(data_dirname / split / "_labels.json", "w") as f: json.dump(labels, f) for ind, crop in enumerate(crops): crop.save(data_dirname / split / f"{ind}.png") def load_processed_crops_and_labels(split: str, data_dirname: Path): """Load line crops and labels for given split from processed directory.""" crops = load_processed_line_crops(split, data_dirname) labels = load_processed_line_labels(split, data_dirname) assert len(crops) == len(labels) return crops, labels def load_processed_line_crops(split: str, data_dirname: Path): """Load line crops for given split from processed directory.""" crop_filenames = sorted((data_dirname / split).glob("*.png"), key=lambda filename: int(Path(filename).stem)) crops = [util.read_image_pil(filename, grayscale=True) for filename in crop_filenames] return crops def load_processed_line_labels(split: str, data_dirname: Path): """Load line labels for given split from processed directory.""" with open(data_dirname / split / "_labels.json") as file: labels = json.load(file) return labels if __name__ == "__main__": load_and_print_info(IAMLines) ================================================ FILE: lab04/text_recognizer/data/iam_paragraphs.py ================================================ """IAM Paragraphs Dataset class.""" import argparse import json from pathlib import Path from typing import Callable, Dict, Optional, Sequence, Tuple import numpy as np from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import 
text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.paragraph import ParagraphStem IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME class IAMParagraphs(BaseDataModule): """IAM Handwriting database paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true").lower() == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = ParagraphStem() self.trainval_transform = ParagraphStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): return rank_zero_info( "IAMParagraphs.prepare_data: Cropping IAM paragraph regions and saving them along with labels..." 
) iam = IAM() iam.prepare_data() properties = {} for split in ["train", "val", "test"]: crops, labels = get_paragraph_crops_and_labels(iam=iam, split=split) save_crops_and_labels(crops=crops, labels=labels, split=split) properties.update( { id_: { "crop_shape": crops[id_].size[::-1], "label_length": len(label), "num_lines": _num_lines(label), } for id_, label in labels.items() } ) with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: def _load_dataset(split: str, transform: Callable) -> BaseDataset: crops, labels = load_processed_crops_and_labels(split) Y = convert_strings_to_labels(strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]) return BaseDataset(crops, Y, transform=transform) rank_zero_info(f"IAMParagraphs.setup({stage}): Loading IAM paragraph regions and lines...") validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) if stage == "fit" or stage is None: self.data_train = _load_dataset(split="train", transform=self.trainval_transform) self.data_val = _load_dataset(split="val", transform=self.transform) if stage == "test" or stage is None: self.data_test = _load_dataset(split="test", transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), 
xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def validate_input_and_output_dimensions( input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] ) -> None: """Validate input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] # Add 2 because of start and end tokens assert output_dims is not None and output_dims[0] >= properties["label_length"]["max"] + 2 def get_paragraph_crops_and_labels( iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR ) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: """Create IAM paragraph crops and labels for a given split, with resizing.""" crops = {} labels = {} for iam_id in iam.ids_by_split[split]: image = iam.load_image(iam_id) para_region = iam.paragraph_region_by_id[iam_id] crops[iam_id] = image.crop([para_region[_] for _ in ["x1", "y1", "x2", "y2"]]) crops[iam_id] = resize_image(crops[iam_id], scale_factor=scale_factor) labels[iam_id] = iam.paragraph_string_by_id[iam_id] assert len(crops) == len(labels) return crops, labels def save_crops_and_labels(crops: Dict[str, Image.Image], labels: Dict[str, str], split: str): """Save crops, labels and shapes of crops of a split.""" (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with open(_labels_filename(split), "w") as f: json.dump(labels, f, indent=4) for id_, crop in crops.items(): crop.save(_crop_filename(id_, split)) def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: """Load processed crops and labels for given split.""" with open(_labels_filename(split), "r") as f: labels = json.load(f) sorted_ids = sorted(labels.keys()) ordered_crops = [Image.open(_crop_filename(id_, split)).convert("L") for id_ in 
sorted_ids] ordered_labels = [labels[id_] for id_ in sorted_ids] assert len(ordered_crops) == len(ordered_labels) return ordered_crops, ordered_labels def get_dataset_properties() -> dict: """Return properties describing the overall dataset.""" with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: properties = json.load(f) def _get_property_values(key: str) -> list: return [_[key] for _ in properties.values()] crop_shapes = np.array(_get_property_values("crop_shape")) aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0] return { "label_length": { "min": min(_get_property_values("label_length")), "max": max(_get_property_values("label_length")), }, "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))}, "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)}, "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, } def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" def _crop_filename(id_: str, split: str) -> Path: """Return filename of processed crop.""" return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" def _num_lines(label: str) -> int: """Return number of lines of text in label.""" return label.count(NEW_LINE_TOKEN) + 1 if __name__ == "__main__": load_and_print_info(IAMParagraphs) ================================================ FILE: lab04/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) 
self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab04/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import string from typing import List, Optional import nltk import numpy as np from text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def __init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample 
text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in string.punctuation}) text = re.sub(" +", " ", text) return text def load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab04/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. 
Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. Parameters ---------- index Returns ------- (datum, target) """ datum, target = self.data[index], self.targets[index] if self.transform is not None: datum = self.transform(datum) if self.target_transform is not None: target = self.target_transform(target) return datum, target def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor: """ Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]: """ Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the other of size (1 - fraction) * size of the base_dataset. """ split_a_size = int(fraction * len(base_dataset)) split_b_size = len(base_dataset) - split_a_size return torch.utils.data.random_split( # type: ignore base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed) ) def resize_image(image: Image.Image, scale_factor: int) -> Image.Image: """Resize image by scale factor.""" if scale_factor == 1: return image return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR) ================================================ FILE: lab04/text_recognizer/lit_models/__init__.py ================================================ from .base import BaseLitModel from .transformer import TransformerLitModel ================================================ FILE: lab04/text_recognizer/lit_models/base.py ================================================ """Basic LightningModules on which other modules can be built.""" import argparse import pytorch_lightning as pl import torch from torchmetrics import Accuracy from .metrics import CharacterErrorRate OPTIMIZER = "Adam" LR = 1e-3 LOSS = "cross_entropy" ONE_CYCLE_TOTAL_STEPS = 100 class BaseLitModel(pl.LightningModule): """ Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
""" def __init__(self, model, args: argparse.Namespace = None): super().__init__() self.model = model self.args = vars(args) if args is not None else {} self.data_config = self.model.data_config self.mapping = self.data_config["mapping"] self.input_dims = self.data_config["input_dims"] optimizer = self.args.get("optimizer", OPTIMIZER) self.optimizer_class = getattr(torch.optim, optimizer) self.lr = self.args.get("lr", LR) loss = self.args.get("loss", LOSS) if loss not in ("transformer",): self.loss_fn = getattr(torch.nn.functional, loss) self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() @staticmethod def add_to_argparse(parser): parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") parser.add_argument("--lr", type=float, default=LR) parser.add_argument("--one_cycle_max_lr", type=float, default=None) parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional") return parser def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr) if self.one_cycle_max_lr is None: return optimizer scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps ) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} def forward(self, x): return self.model(x) def predict(self, x): logits = self.model(x) return torch.argmax(logits, dim=1) def training_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.train_acc(logits, y) self.log("train/loss", loss) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) outputs = {"loss": loss} 
self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def _run_on_batch(self, batch, with_preds=False): x, y = batch logits = self(x) loss = self.loss_fn(logits, y) return x, y, logits, loss def validation_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.val_acc(logits, y) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) outputs = {"loss": loss} self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.test_acc(logits, y) self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) def add_on_first_batch(self, metrics, outputs, batch_idx): if batch_idx == 0: outputs.update(metrics) def add_on_logged_batches(self, metrics, outputs): if self.is_logged_batch(): outputs.update(metrics) def is_logged_batch(self): if self.trainer is None: return False else: return self.trainer._logger_connector.should_update_logs class BaseImageToTextLitModel(BaseLitModel): # pylint: disable=too-many-ancestors """Base class for ImageToText models in PyTorch Lightning.""" def __init__(self, model, args: argparse.Namespace = None): super().__init__(model, args) self.model = model self.args = vars(args) if args is not None else {} self.inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} self.start_index = self.inverse_mapping["<S>"] self.end_index = self.inverse_mapping["<E>"] self.padding_index = self.inverse_mapping["<P>
"] self.ignore_tokens = [self.start_index, self.end_index, self.padding_index] self.val_cer = CharacterErrorRate(self.ignore_tokens) self.test_cer = CharacterErrorRate(self.ignore_tokens) ================================================ FILE: lab04/text_recognizer/lit_models/metrics.py ================================================ """Special-purpose metrics for tracking our model performance.""" from typing import Sequence import torch import torchmetrics class CharacterErrorRate(torchmetrics.CharErrorRate): """Character error rate metric, allowing for tokens to be ignored.""" def __init__(self, ignore_tokens: Sequence[int], *args): super().__init__(*args) self.ignore_tokens = set(ignore_tokens) def update(self, preds: torch.Tensor, targets: torch.Tensor): # type: ignore preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()] targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()] super().update(preds_l, targets_l) def test_character_error_rate(): metric = CharacterErrorRate([0, 1]) X = torch.tensor( [ [0, 2, 2, 3, 3, 1], # error will be 0 [0, 2, 1, 1, 1, 1], # error will be .75 [0, 2, 2, 4, 4, 1], # error will be .5 ] ) Y = torch.tensor( [ [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], ] ) metric(X, Y) assert metric.compute() == sum([0, 0.75, 0.5]) / 3 if __name__ == "__main__": test_character_error_rate() ================================================ FILE: lab04/text_recognizer/lit_models/transformer.py ================================================ """An encoder-decoder Transformer model""" from typing import List, Sequence import torch from .base import BaseImageToTextLitModel from .util import replace_after class TransformerLitModel(BaseImageToTextLitModel): """ Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module. 
The module must implement an encode and decode method, and the forward method should be the forward pass during production inference. """ def __init__(self, model, args=None): super().__init__(model, args) self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index) def forward(self, x): return self.model(x) def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x. Parameters ---------- x Batch of images to be encoded. See self.model.encode for shape information. y Batch of ground truth output sequences. Returns ------- torch.Tensor (B, C, Sy) logits """ x = self.model.encode(x) output = self.model.decode(x, y) # (Sy, B, C) return output.permute(1, 2, 0) # (B, C, Sy) def training_step(self, batch, batch_idx): x, y = batch logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("train/loss", loss) outputs = {"loss": loss} if self.is_logged_batch(): preds = self.get_preds(logits) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs}) return outputs def validation_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.val_cer(preds, y) self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) 
loss = self.loss_fn(logits, y[:, 1:]) self.log("test/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.test_cer(preds, y) self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def map(self, ks: Sequence[int], ignore: bool = True) -> str: """Maps an iterable of integers to a string using the lit model's mapping.""" if ignore: return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens]) else: return "".join([self.mapping[k] for k in ks]) def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]: """Maps a list of lists of integers to a list of strings using the lit model's mapping.""" return [self.map(k, ignore) for k in ks] def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor: """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index. Parameters ---------- logitlikes (B, C, Sy) Tensor with classes as second dimension. The largest value is the one whose index we will return. Logits, logprobs, and probs are all acceptable. replace_after_end Whether to replace values after the first appearance of the end token with the padding token. Returns ------- torch.Tensor (B, Sy) Tensor of integers in [0, C-1] representing predictions.
""" raw = torch.argmax(logitlikes, dim=1) # (B, C, Sy) -> (B, Sy) if replace_after_end: return replace_after(raw, self.end_index, self.padding_index) # (B, Sy) else: return raw # (B, Sy) ================================================ FILE: lab04/text_recognizer/lit_models/util.py ================================================ from typing import Union import torch def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: """Return indices of first appearance of element in x, collapsing along dim. Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 Parameters ---------- x One or two-dimensional Tensor to search for element. element Item to search for inside x. dim Dimension of Tensor to collapse over. Returns ------- torch.Tensor Indices where element occurs in x. If element is not found, return length of x along dim. One dimension smaller than x. Raises ------ ValueError if x is not a 1 or 2 dimensional Tensor Examples -------- >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3) tensor([2, 1, 3, 0]) >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0) tensor(0) """ if x.dim() > 2 or x.dim() == 0: raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}") matches = x == element first_appearance_mask = (matches.cumsum(dim) == 1) & matches does_match, match_index = first_appearance_mask.max(dim) first_inds = torch.where(does_match, match_index, x.shape[dim]) return first_inds def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor: """Replace all values in each row of 2d Tensor x after the first appearance of element with replace. Parameters ---------- x Two-dimensional Tensor (shape denoted (B, S)) to replace values in. element Item to search for inside x. replace Item that replaces entries that appear after element. 
Returns ------- outs New Tensor of same shape as x with values after element replaced. Examples -------- >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4) tensor([[1, 2, 3], [2, 3, 4], [1, 1, 1], [3, 4, 4]]) """ first_appearances = first_appearance(x, element, dim=1) # (B,) indices = torch.arange(0, x.shape[-1]).type_as(x) # (S,) outs = torch.where( indices[None, :] <= first_appearances[:, None], # if index is before first appearance x, # return the value from x replace, # otherwise, return the replacement value ) return outs # (B, S) ================================================ FILE: lab04/text_recognizer/metadata/emnist.py ================================================ from pathlib import Path import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json" NUM_SPECIAL_TOKENS = 4 INPUT_SHAPE = (28, 28) DIMS = (1, *INPUT_SHAPE) # Extra dimension added by ToTensor() OUTPUT_DIMS = (1,) MAPPING = [ "<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] ================================================ FILE: lab04/text_recognizer/metadata/emnist_lines.py ================================================ from pathlib import Path import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json" CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3] DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None) # width variable, depends on maximum sequence length MAPPING = emnist.MAPPING ================================================ FILE: lab04/text_recognizer/metadata/iam.py ================================================ import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam" EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" DOWNSAMPLE_FACTOR = 2 # if images were downsampled, the regions must also be LINE_REGION_PADDING = 8 # add this many pixels around the exact coordinates ================================================ FILE: lab04/text_recognizer/metadata/iam_lines.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines" IMAGE_SCALE_FACTOR = 2 CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR # rounding up IAMLines empirical maximum width DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (89, 1) MAPPING = emnist.MAPPING ================================================ FILE: lab04/text_recognizer/metadata/iam_paragraphs.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN] IMAGE_SCALE_FACTOR = 2 IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640 IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) MAX_LABEL_LENGTH = 682 DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) ================================================ FILE: lab04/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab04/text_recognizer/metadata/shared.py ================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab04/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple from .resnet_transformer import ResnetTransformer from .line_cnn_transformer import LineCNNTransformer ================================================ FILE: lab04/text_recognizer/models/cnn.py 
================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. 
conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab04/text_recognizer/models/line_cnn.py ================================================ """Basic building blocks for convolutional models over lines of text.""" import argparse import math from typing import Any, Dict, Tuple, Union import torch from torch import nn import torch.nn.functional as F # Common type hints Param2D = Union[int, Tuple[int, int]] CONV_DIM = 32 FC_DIM = 512 FC_DROPOUT = 0.2 WINDOW_WIDTH = 16 WINDOW_STRIDE = 8 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. 
""" def __init__( self, input_channels: int, output_channels: int, kernel_size: Param2D = 3, stride: Param2D = 1, padding: Param2D = 1, ) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class LineCNN(nn.Module): """ Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits """ def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.args = vars(args) if args is not None else {} self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] _C, H, _W = data_config["input_dims"] conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) # Input is (1, H, W) self.convs = nn.Sequential( ConvBlock(1, conv_dim), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim, stride=2), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim * 2, stride=2), ConvBlock(conv_dim * 2, conv_dim * 2), ConvBlock(conv_dim * 2, conv_dim * 4, stride=2), ConvBlock(conv_dim * 4, conv_dim * 4), ConvBlock( conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0 ), ) self.fc1 = nn.Linear(fc_dim, fc_dim) self.dropout = nn.Dropout(fc_dropout) self.fc2 = nn.Linear(fc_dim, self.num_classes) self._init_weights() def _init_weights(self): """ Initialize weights in a better way 
than default. See https://github.com/pytorch/pytorch/issues/18182 """ for m in self.modules(): if type(m) in { nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear, }: nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu") if m.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(m.bias, -bound, bound) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the LineCNN to a black-and-white input image. Parameters ---------- x (B, 1, H, W) input image Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and self.window_width C is self.num_classes """ _B, _C, _H, _W = x.shape x = self.convs(x) # (B, FC_DIM, 1, Sx) x = x.squeeze(2).permute(0, 2, 1) # (B, S, FC_DIM) x = F.relu(self.fc1(x)) # -> (B, S, FC_DIM) x = self.dropout(x) x = self.fc2(x) # (B, S, C) x = x.permute(0, 2, 1) # -> (B, C, S) if self.limit_output_length: x = x[:, :, : self.output_length] return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab04/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import 
torch from torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. 
Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab04/text_recognizer/models/line_cnn_transformer.py ================================================ """Model that combines a LineCNN with a Transformer model for text prediction.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .line_cnn import LineCNN from .transformer_util import generate_square_subsequent_mask, PositionalEncoding TF_DIM = 256 TF_FC_DIM = 256 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 class LineCNNTransformer(nn.Module): """Process the line through a CNN and process the resulting sequence with a Transformer decoder.""" def __init__( self, data_config: 
Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # Instantiate LineCNN with "num_classes" set to self.dim data_config_for_line_cnn = {**data_config} data_config_for_line_cnn["mapping"] = list(range(self.dim)) self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args) # LineCNN outputs (B, E, S) log probs, with E == dim self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.pos_encoder = PositionalEncoding(d_model=self.dim) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (Sx, B, E) logits """ x = self.line_cnn(x) # (B, E, Sx) x = x * math.sqrt(self.dim) x = x.permute(2, 0, 1) # (Sx, B, E) x = self.pos_encoder(x) # (Sx, B, E) return x def decode(self, x, y): """Decode a batch of encoded images x using preceding ground truth y.
Parameters ---------- x (Sx, B, E) image encoded as a sequence y (B, Sy) with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) logits """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output def forward(self, x: torch.Tensor) -> torch.Tensor: """Predict sequences of tokens from input images auto-regressively. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1:] # Set the last output token # Set all tokens after end token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) @staticmethod def add_to_argparse(parser): LineCNN.add_to_argparse(parser) parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser ================================================ FILE: lab04/text_recognizer/models/mlp.py 
================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab04/text_recognizer/models/resnet_transformer.py ================================================ """Model combining a ResNet with a Transformer for image-to-sequence tasks.""" import argparse import math from typing import Any, Dict import torch from torch import nn import torchvision from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage TF_DIM = 256 TF_FC_DIM = 1024 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 RESNET_DIM = 512 # hard-coded class ResnetTransformer(nn.Module): """Pass an image through a Resnet and decode the resulting embedding with a Transformer.""" def __init__( 
self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) self.mapping = data_config["mapping"] inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # ## Encoder part - should output vector sequence of length self.dim per sample resnet = torchvision.models.resnet18(weights=None) self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) # encoder_projection will output (B, dim, _H, _W) logits self.enc_pos_encoder = PositionalEncodingImage( d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] ) # Max (Ho, Wo) # ## Decoder part self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def forward(self, x: torch.Tensor) -> torch.Tensor: """Autoregressively produce sequences of labels from input images.
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- output_tokens (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1] # Set the last output token # Early stopping of prediction loop to speed up prediction if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): break # Set all tokens after end or padding token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") if self.encoder_projection.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.encoder_projection.bias, -bound, bound) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps """ _B, C, _H, _W = x.shape if C == 1: x = x.repeat(1, 3, 1, 1) x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo return x def decode(self, x, y): """Decode a batch of encoded images x with guiding sequences y. During autoregressive inference, the guiding sequence will be previous predictions. During training, the guiding sequence will be the ground truth. Parameters ---------- x (Sx, B, E) images encoded as sequences of embeddings y (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) batch of logit sequences """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.dec_pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output @staticmethod def add_to_argparse(parser): parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser 
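The cleanup pass at the end of `ResnetTransformer.forward` (the loop commented "Set all tokens after end or padding token to be padding") can be illustrated without PyTorch. The following is a minimal sketch in plain Python for a single sequence, not the repository's implementation: the helper name `pad_after_end` and the token indices used in the example are hypothetical and chosen only for illustration.

```python
from typing import List


def pad_after_end(tokens: List[int], end_token: int, padding_token: int) -> List[int]:
    """Replace every token that follows an end or padding token with padding."""
    out = list(tokens)  # copy so the caller's list is not mutated
    for i in range(1, len(out)):
        # Same condition as the tensor loop in forward: inspect the previous position.
        if out[i - 1] in (end_token, padding_token):
            out[i] = padding_token
    return out


# Hypothetical indices for illustration only: 1 = start, 2 = end, 3 = padding.
cleaned = pad_after_end([1, 7, 8, 2, 9, 9], end_token=2, padding_token=3)
print(cleaned)  # [1, 7, 8, 2, 3, 3]
```

Because each position looks at the already-cleaned previous position, a single left-to-right pass is enough: once one padding token is written, it propagates to the end of the sequence. The tensor version in `forward` applies the same rule to a whole batch at once by indexing rows with a boolean mask.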
================================================ FILE: lab04/text_recognizer/models/transformer_util.py ================================================ """Position Encoding and other utilities for Transformers.""" import math import torch from torch import Tensor import torch.nn as nn class PositionalEncodingImage(nn.Module): """ Module used to add 2-D positional encodings to the feature-map produced by the encoder. Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. """ def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None: super().__init__() self.d_model = d_model assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w) self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor: pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1, d_model // 2) pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) return pe def forward(self, x: Tensor) -> Tensor: """pytorch.nn.module.forward""" # x.shape = (B, d_model, H, W) assert x.shape[1] == self.pe.shape[0] # type: ignore x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore return x class PositionalEncoding(torch.nn.Module): """Classic Attention-is-all-you-need positional encoding.""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None: super().__init__() self.dropout = torch.nn.Dropout(p=dropout) pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model) 
self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_len: int) -> torch.Tensor: pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(1) return pe def forward(self, x: torch.Tensor) -> torch.Tensor: # x.shape = (S, B, d_model) assert x.shape[2] == self.pe.shape[2] # type: ignore x = x + self.pe[: x.size(0)] # type: ignore return self.dropout(x) def generate_square_subsequent_mask(size: int) -> torch.Tensor: """Generate a triangular (size, size) mask.""" mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask ================================================ FILE: lab04/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible. 
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab04/text_recognizer/stems/line.py ================================================ import random from PIL import Image from torchvision import transforms import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.image import ImageStem class LineStem(ImageStem): """A stem for handling images containing a line of text.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.5, 1)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "translate": (0, 0.05), "scale": (0.4, 1.1), "shear": (-40, 50), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } if augment: self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] ) class IAMLineStem(ImageStem): """A stem for handling images containing lines of text from the IAMLines dataset.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() def embed_crop(crop, augment=augment): # crop is PIL.image of dtype="L" (so values range from 0 -> 255) image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) # Resize crop crop_width, crop_height = crop.size new_crop_height = metadata.IMAGE_HEIGHT new_crop_width = int(new_crop_height * 
(crop_width / crop_height)) if augment: # Add random stretching new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR) # Embed in the image x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) y = metadata.IMAGE_HEIGHT - new_crop_height image.paste(crop_resized, (x, y)) return image if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.8, 1.6)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 1, "shear": (-30, 20), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } pil_transforms_list = [transforms.Lambda(embed_crop)] if augment: pil_transforms_list += [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] self.pil_transforms = transforms.Compose(pil_transforms_list) ================================================ FILE: lab04/text_recognizer/stems/paragraph.py ================================================ """IAMParagraphs Stem class.""" import torchvision.transforms as transforms import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.image import ImageStem IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH class ParagraphStem(ImageStem): """A stem for handling images that contain a paragraph of text.""" def __init__( self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None, random_perspective_kwargs=None, gaussian_blur_kwargs=None, sharpness_kwargs=None, ): super().__init__() if not augment: self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)]) else: if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "shear": 6, 
"scale": (0.95, 1), "interpolation": transforms.InterpolationMode.BILINEAR, } if random_perspective_kwargs is None: random_perspective_kwargs = { "distortion_scale": 0.2, "p": 0.5, "interpolation": transforms.InterpolationMode.BILINEAR, } if gaussian_blur_kwargs is None: gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if sharpness_kwargs is None: sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} # IMAGE_SHAPE is (576, 640) self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomCrop( size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant" ), transforms.RandomAffine(**random_affine_kwargs), transforms.RandomPerspective(**random_perspective_kwargs), transforms.GaussianBlur(**gaussian_blur_kwargs), transforms.RandomAdjustSharpness(**sharpness_kwargs), ] ) ================================================ FILE: lab04/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() 
os.chdir(working_dir) try: yield finally: os.chdir(curdir) def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab04/training/__init__.py ================================================ ================================================ FILE: lab04/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import callbacks as cb from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. 
np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--wandb", action="store_true", default=False, help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.", ) parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
+ " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. 
To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.loss == "transformer": lit_model_class = lit_models.TransformerLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" if goldstar_metric == "validation/cer": filename_format += "-validation.cer={validation/cer:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.wandb: logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train") logger.watch(model, log_freq=max(100, args.log_every_n_steps)) logger.log_hyperparams(vars(args)) experiment_dir = logger.experiment.dir callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) if args.wandb and args.loss in ("transformer",): callbacks.append(cb.ImageToTextLogger()) trainer = 
pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") if args.wandb: rank_zero_info("Best model also uploaded to W&B ") trainer.test(datamodule=data, ckpt_path=best_model_path) else: trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab04/training/util.py ================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: lab05/.flake8 ================================================ [flake8] select = ANN,B,B9,BLK,C,D,E,F,I,S,W # only check selected error codes max-complexity = 12 # C9 - flake8 McCabe Complexity checker -- threshold max-line-length = 120 # E501 - flake8 -- line length too long, actually handled by black extend-ignore = # E W - flake8 PEP style check E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks # S - flake8-bandit safety check S101,S113,S311,S105, # 
assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password # ANN - flake8-annotations type annotation check ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some # D1 - flake8-docstrings docstring style check D100,D102,D103,D104,D105, # missing docstrings # D2 D4 - flake8-docstrings docstring style check D200,D205,D400,D401, # whitespace issues and first line content # DAR - flake8-darglint docstring correctness check DAR103, # mismatched or missing type in docstring application-import-names = app_gradio,text_recognizer,tests,training # flake8-import-order: which names are first party? import-order-style = google # flake8-import-order: which import order style guide do we use? docstring-convention = numpy # flake8-docstrings: which docstring style guide do we use? strictness = short # darglint: how "strict" are we with docstring completeness? docstring-style = numpy # darglint: which docstring style guide do we use? suppress-none-returning = true # flake8-annotations: do we allow un-annotated Nones in returns? mypy-init-return = true # flake8-annotations: do we allow init to have no return annotation? 
per-file-ignores = # list of case-by-case ignores, see files for details */__init__.py:F401,I */data/*.py:DAR data/*.py:F,I *text_recognizer/util.py:DAR101,F401 *training/run_experiment.py:I202 *app_gradio/app.py:I202 ================================================ FILE: lab05/.github/workflows/pre-commit.yml ================================================ name: pre-commit on: pull_request: push: # allows this Action to be triggered manually workflow_dispatch: jobs: pre-commit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: python-version: '3.10' - uses: pre-commit/action@v3.0.0 ================================================ FILE: lab05/.pre-commit-config.yaml ================================================ repos: # a set of useful Python-based pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: # list of definitions and supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace # removes any whitespace at the ends of lines - id: check-toml # check toml syntax by loading all toml files - id: check-yaml # check yaml syntax by loading all yaml files - id: check-json # check-json syntax by loading all json files - id: check-merge-conflict # check for files with merge conflict strings args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge - id: check-added-large-files # check that no "large" files have been added args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server - id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.) - id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.) 
# black python autoformatting - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black # additional configuration of black in pyproject.toml # flake8 python linter with all the fixins - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08) additional_dependencies: [ flake8-bandit, flake8-bugbear, flake8-docstrings, flake8-import-order, darglint, mypy, pycodestyle, pydocstyle] args: ["--config", ".flake8"] # additional configuration of flake8 and extensions in .flake8 # shellcheck-py for linting shell files - repo: https://github.com/shellcheck-py/shellcheck-py rev: v0.8.0.4 hooks: - id: shellcheck ================================================ FILE: lab05/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid 
training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For refreshing on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": 
"markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the Python standard library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing\n", "and matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models." 
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randrange(len(x_train)) # randrange excludes the upper bound, so idx is always a valid index\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch." 
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how well our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", "    acc.backward()\n", "except RuntimeError as e:\n", "    print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5  # learning rate hyperparameter\n", "epochs = 2  # how many epochs to train for\n", "\n", "for epoch in range(epochs):  # loop over the data repeatedly\n", "    for ii in range((n - 1) // bs + 1):  # in batches of size bs, so roughly n / bs of them\n", "        start_idx = ii * bs  # we are ii batches in, each of size bs\n", "        end_idx = start_idx + bs  # and we want the next bs entries\n", "\n", "        # pull batches from x and from y\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "\n", "        # run model\n", "        pred = model(xb)\n", "\n", "        # get loss\n", "        loss = loss_func(pred, yb)\n", "\n", "        # calculate the gradients with a backwards pass\n", "        loss.backward()\n", "\n", "        # update the parameters\n", "        with torch.no_grad():  # we don't want to track gradients through this part!\n", "            # SGD learning rule: update with negative gradient scaled by lr\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "\n", "            # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", "            # until you say so -- by explicitly \"deleting\" them,\n", "            # i.e. 
setting the gradients to 0.\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1)  # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset was, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`."
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need\n", "is already out there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", "    return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))  # should be unchanged from above!"
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()  # the nn.Module.__init__ method does important setup, so this is mandatory\n", "        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", "        self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): # finds params automatically\n", " p -= p.grad * lr\n", " model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", " def forward(self, xb):\n", " return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", "    opt = configure_optimizer(self)\n", "\n", "    for epoch in range(epochs):\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten-line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!"
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", " weights.grad.zero_()\n", " bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move electrons around."
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", " return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", " \n", " def __init__(self, dir, bs=32):\n", " self.dir = dir\n", " self.bs = bs\n", " self.path = self.dir / self.filename\n", "\n", " def prepare_data(self):\n", " if not (self.path).exists():\n", " content = requests.get(self.url + self.filename).content\n", " self.path.open(\"wb\").write(content)\n", "\n", " def setup(self):\n", " with gzip.open(self.path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", " x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", " )\n", " \n", " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", " def train_dataloader(self):\n", " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", " \n", " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", " datamodule.prepare_data()\n", " datamodule.setup()\n", "\n", " val_dataloader = datamodule.val_dataloader()\n", " \n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", " opt = configure_optimizer(self)\n", " train_dataloader = datamodule.train_dataloader()\n", " for epoch in range(epochs):\n", " self.train()\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab05/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", " inputs, targets = batch\n", "\n", " outputs = model(inputs)\n", " loss = some_loss_function(outputs, targets)\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", "\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes."
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", " def __init__(self):\n", " super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", " # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", " self.model = torch.nn.Linear(in_features=1, out_features=1)\n", " # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", " # optionally, define a forward method\n", " def forward(self, xs):\n", " return self.model(xs) # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try: # try using the LinearRegression LightningModule defined above\n", " logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n", "\n", " model = LinearRegression()\n", "\n", " # we'll explain how the Trainer works in a bit\n", " trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", " trainer.fit(model=model) \n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", " print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n", "\n", "finally: # bring back info-level logging\n", " logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message 
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", " # run the train step\n", " loss = training_step(batch)\n", "\n", " # clear gradients\n", " optimizer.zero_grad()\n", "\n", " # backprop\n", " loss.backward()\n", "\n", " # update parameters\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " xs, ys = batch # unpack the batch\n", " outs = self(xs) # apply the model\n", " loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n", " return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we 
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudocode (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", " def closure(): # wrap the loss calculation\n", " loss = model.training_step(batch, batch_idx, ...)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " return loss\n", "\n", " # optimizer can call the loss calculation as many times as it likes\n", " optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", " optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n", " return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n", 
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", " def __init__(self, N=10_000):\n", " self.N = N\n", " self.xs = torch.randn(size=(N, 1))\n", " self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n", "\n", " def __getitem__(self, idx):\n", " return (self.xs[idx], self.ys[idx])\n", "\n", " def __len__(self):\n", " return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging " ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", "    def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", "        super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", "        # set some constants, like the train/val split\n", "        self.size = size\n", "        self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", "        self.train_indices = list(range(math.floor(self.size * train_frac)))\n", "        # start right after the last training index so the two splits don't overlap\n", "        self.val_indices = list(range(self.train_indices[-1] + 1, self.size))\n", "\n", "        # under the hood, we've still got a torch Dataset\n", "        self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up once on each machine\n", "(for example, downloading data from storage and writing it to the local disk)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates best performance practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    pass # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    pass # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", "    def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n", "        super().__init__() # don't forget this!\n", "        self.dataset = None\n", "        self.test_dataset = None # define a test set -- another sample from the same distribution\n", "\n", "    def setup(self, stage=None):\n", "        pass\n", "\n", "    def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", "        pass # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null,
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab05/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?" 
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we apply that kernel repeatedly,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type": 
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some, possibly different, number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh" 
}, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as feature maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represents (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "variables.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of 
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work off of,\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors." 
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
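, "\n", "\n", "The cell below is a minimal sketch of this shape arithmetic, not the layers of our `CNN` class:\n", "the input shape, the channel counts, and the 2x2 max pooling are assumptions chosen for illustration.\n", "With 2x2 pooling, height and width are each halved,\n", "so the \"right factor\" for the number of kernels would be four." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "kernel_size, stride = 3, 1\n", "padding = kernel_size // 2  # keeps height and width fixed for unit stride and odd kernel size\n", "\n", "x = torch.randn(1, 32, 28, 28)  # hypothetical batch of stacked feature maps\n", "conv = torch.nn.Conv2d(32, 32, kernel_size=kernel_size, stride=stride, padding=padding)\n", "pool = torch.nn.MaxPool2d(2)  # assumed 2x2 pooling: halves height and width\n", "widen = torch.nn.Conv2d(32, 4 * 32, kernel_size=kernel_size, stride=stride, padding=padding)\n", "\n", "with torch.no_grad():\n", "    same = conv(x)\n", "    pooled = pool(same)\n", "    widened = widen(pooled)\n", "\n", "print(x.shape, same.shape, pooled.shape, widened.shape, sep=\"\\n\")\n", "print(x.numel() == widened.numel())  # four times the channels offsets the 2x2 pooling"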
] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = max(cnn.parameters(), key=lambda p: p.numel())  # the parameter tensor with the most entries\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like dropout\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by a convolutional layer can be higher\n", "than that done by layers with more parameters --\n", "much higher." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2d layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)):  # this is a rough and dirty approximation\n", "    # ignores padding, so it undercounts somewhat when padding is used\n", "    num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", "    input_height, input_width = input_size[1], input_size[2]\n", "\n", "    multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", "    num_applications = (input_height - kernel_height + 1) * (input_width - kernel_width + 1)\n", "    multiplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", "    return multiplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory-constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n", "filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"CNN\",\n", " \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", " --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels." 
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `<P>`.\n", "\n", "A model that always predicts `<P>` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"<P>\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's 
behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab05/notebooks/lab03_transformers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 03: Transformers and Paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The fundamental reasons why the Transformer is such\n", "a powerful and popular architecture\n", "- Core intuitions for the behavior of Transformer architectures\n", "- How to use a convolutional encoder and a Transformer decoder to recognize\n", "entire paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 3\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Transformers?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal in building a text recognizer is to take a two-dimensional image\n", "and convert it into a one-dimensional sequence of characters\n", "from some alphabet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional neural networks,\n", "discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n", "are great at encoding images,\n", "taking them from their raw pixel values\n", "to a more semantically meaningful numerical representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But how do we go from that to a sequence of letters?\n", "And what's especially tricky:\n", "the number of letters in an image is separable from its size.\n", "A screenshot of this document has a much higher density of letters\n", "than a close-up photograph of a piece of paper.\n", "How do we get a _variable-length_ sequence of letters,\n", "where the length need have nothing to do with the size of the input tensor?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n", "they were\n", "[originally introduced](https://arxiv.org/abs/1706.03762)\n", "for transforming one sequence into another,\n", "as in machine translation.\n", "This makes them a natural fit for processing language.\n", "\n", "But they have also found success in other domains --\n", "at the time of this writing, large transformers\n", "dominate the\n", "[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n", "that has become a de facto standard for comparing models\n", "and are finding\n", "[application in reinforcement learning](https://arxiv.org/abs/2106.01345)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we will use a Transformer as a key component of our final architecture:\n", "we will encode our input images with a CNN\n", "and then read them out into a text sequence with a Transformer.\n", "\n", "Before trying out this new model,\n", "let's first get an understanding of why the Transformer architecture\n", "has become so popular by walking through its history\n", "and then get some intuition for how it works\n", "by looking at some\n", "[recent work](https://transformer-circuits.pub/)\n", "on explaining the behavior of both toy models and state-of-the-art language models." 
] }, { "cell_type": "markdown", "metadata": { "id": "kmKqjbvd-Mj3" }, "source": [ "## Why not convolutions?" ] }, { "cell_type": "markdown", "metadata": { "id": "SRqkUMdM-OxU" }, "source": [ "In the ancient beforetimes (i.e. 2016),\n", "the best models for natural language processing were all\n", "_recurrent_ neural networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional networks were also occasionally used,\n", "but they suffered from a serious issue:\n", "their architectural biases don't fit text.\n", "\n", "First, _translation equivariance_ no longer holds.\n", "The beginning of a piece of text is often quite different from the middle,\n", "so the absolute position matters.\n", "\n", "Second, _locality_ is not as important in language.\n", "The name of a character that hasn't appeared in thousands of pages\n", "can become salient when someone asks, \"Whatever happened to\n", "[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n", "\n", "Consider interpreting a piece of text like the Python code below:\n", "```python\n", "def do(arg1, arg2, arg3):\n", " a = arg1 + arg2\n", " b = arg3[:3]\n", " c = a * b\n", " return c\n", "\n", "print(do(1, 1, \"ayy lmao\"))\n", "```\n", "\n", "After a `(` we expect a `)`,\n", "but possibly very long afterwards,\n", "[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n", "and similarly we expect a `]` at some point after a `[`.\n", "\n", "For translation variance, consider\n", "that we interpret `*` not by\n", "comparing it to its neighbors\n", "but by looking at `a` and `b`.\n", "We mix knowledge learned through experience\n", "with new facts learned while reading --\n", "also known as _in-context learning_.\n", "\n", "In a longer text,\n", "[e.g. 
the one you are reading now](./lab03_transformers.ipynb),\n", "the translation variance of text is clearer.\n", "Every lab notebook begins with the same header,\n", "setting up the environment,\n", "but that header never appears elsewhere in the notebook.\n", "Later positions need to be processed in terms of the previous entries.\n", "\n", "Unlike an image, we cannot simply rotate or translate our \"camera\"\n", "and get a new valid text.\n", "[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n", "that can be read without regard to position." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The field of formal language theory,\n", "which has deep mutual influence with computer science,\n", "gives one way of explaining the issues with convolutional networks:\n", "they can only understand languages with _finite contexts_,\n", "where all the information can be found within a finite window." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The immediate solution, drawing from the connections to computer science, is\n", "[recursion](https://www.google.com/search?q=recursion).\n", "A network whose output on the final entry of the sequence is a recursive function\n", "of all the previous entries can build up knowledge\n", "as it reads the sequence and treat early entries quite differently than it does late ones." 
] }, { "cell_type": "markdown", "metadata": { "id": "aa6cbTlImkEh" }, "source": [ "In pseudo-code, such a _recurrent neural network_ module might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "lKtBoPnglPrW" }, "source": [ "```python\n", "def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n", " next_inputs = input_module(xs[-1])\n", " next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n", " return output_module(next_inputs, next_hiddens)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "IbJPSMnEm516" }, "source": [ "If you've had formal computer science training,\n", "then you may be familiar with the power of recursion,\n", "e.g. the\n", "[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n", "that gave its name to the now much better-known\n", "[startup incubator](https://www.ycombinator.com/).\n", "\n", "The particular form of recursion used by\n", "recurrent neural networks implements a\n", "[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n", "\n", "> If you know a lot of computer science,\n", "you might be concerned by this connection.\n", "What about other\n", "[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n", "Where are the neural network architectures for differentiable\n", "[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n", "Check out Graph Neural Networks,\n", "[which implement dynamic programming](https://arxiv.org/abs/2203.15544)." 
] }, { "cell_type": "markdown", "metadata": { "id": "63mMTbEBpVuE" }, "source": [ "Recurrent networks are able to achieve\n", "[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n", "\n", "There are many popular recurrent architectures,\n", "from the beefy and classic\n", "[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n", "to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n", "([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n", "all of which have roughly similar capabilities but\n", "[some of which are easier to train](https://arxiv.org/abs/1611.09913)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwQHVTIslOku" }, "source": [ "In the same sense that MLPs can model \"any\" feedforward function,\n", "in principle even basic RNNs\n", "[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n", "\n", "In particular they can model any\n", "[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n", "which is a formal way of saying that they can in principle\n", "do anything a computer is capable of doing.\n", "\n", "The question is then..." ] }, { "cell_type": "markdown", "metadata": { "id": "3J8EoGN3pu7P" }, "source": [ "## Why aren't we all using RNNs?" 
] }, { "cell_type": "markdown", "metadata": { "id": "TDwNWaevpt_3" }, "source": [ "The guarantees that MLPs can model any function\n", "or that RNNs can model Turing machines\n", "provide decent intuition but are not directly practically useful.\n", "Among other reasons, they don't guarantee learnability --\n", "that starting from random parameters we can find the parameters\n", "that implement a given function.\n", "The\n", "[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n", "than would seem from basic theoretical and empirical analysis.\n", "\n", "One way of understanding capacity to model language is\n", "[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n", "In this model of formal languages,\n", "Turing machines sit at the top\n", "([practically speaking](https://arxiv.org/abs/math/0209332)).\n", "\n", "With better mathematical models,\n", "RNNs and LSTMs can be shown to be\n", "[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n", "with RNNs looking more like\n", "[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n", "and LSTMs coming in\n", "[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n", "\n", "More controversially:\n", "the Chomsky hierarchy is great for understanding syntax and grammar,\n", "which makes it great for building parsers\n", "and working with formal languages,\n", "but the goal in _natural_ language processing is to understand _natural_ language.\n", "Most humans' natural language is far from strictly grammatical,\n", "but that doesn't mean it is nonsense.\n", "\n", "And to really \"understand\" language means\n", "to understand its semantic content, which is fuzzy.\n", "The most important thing for handling the fuzzy semantic content\n", "of language is not whether you can recall\n", "[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n", "but whether you can 
model probabilistic relationships between concepts\n", "in addition to grammar and syntax." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These both leave theoretical room for improvement over current recurrent\n", "language and sequence models.\n", "\n", "But the real cause of the rise of Transformers is that..." ] }, { "cell_type": "markdown", "metadata": { "id": "Dsu1ebvAp-3Z" }, "source": [ "## Transformers are designed to train fast at scale on contemporary hardware." ] }, { "cell_type": "markdown", "metadata": { "id": "c4abU5adsPGs" }, "source": [ "The Transformer architecture has several important features,\n", "discussed below,\n", "but one of the most important reasons why it is successful\n", "is that it can be more easily trained at scale.\n", "\n", "This scalability is the focus of the discussion in the paper\n", "that introduced the architecture,\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", "and\n", "[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n", "\n", "The recursion in RNNs is inherently sequential:\n", "the dependence on the outputs from earlier in the sequence\n", "means computations within an example cannot be parallelized.\n", "\n", "So RNNs must batch across examples to scale,\n", "but as sequence length grows this hits memory bandwidth limits.\n", "Serving up large batches quickly with good randomness guarantees\n", "is also hard to optimize,\n", "especially in distributed settings.\n", "\n", "The Transformer architecture,\n", "on the other hand,\n", "can be readily parallelized within a single example sequence,\n", "in addition to parallelization across batches.\n", "This can lead to massive performance gains for a fixed scale,\n", "which means larger, higher capacity models\n", "can be trained on larger datasets."
] }, { "cell_type": "markdown", "metadata": { "id": "_Mzk2haFC_G1" }, "source": [ "How does the architecture achieve this parallelizability?\n", "\n", "Let's start with the architecture diagram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u59eu4snLQfp" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n", "\n", "display.Image(url=base_url + \"/aiayn-figure-1.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ez-XEQ7M0UlR" }, "source": [ "> To head off a bit of confusion\n", " in case you've worked with Transformer architectures before:\n", " the original \"Transformer\" is an encoder/decoder architecture.\n", " Many LLMs, like GPT models, are decoder only,\n", " because this has turned out to scale well,\n", " and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n", " it's all text anyways.\n", " We, however, will be using them across modalities,\n", " so we need an explicit encoder,\n", " as above. " ] }, { "cell_type": "markdown", "metadata": { "id": "ok4ksBi4vp89" }, "source": [ "First focusing on the encoder (left):\n", "the encoding at a given position is a function of all previous inputs.\n", "But it is not a function of the previous _encodings_:\n", "we produce the encodings \"all at once\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "RPN7C-_OqzHP" }, "source": [ "The decoder (right) does use previous \"outputs\" as its inputs,\n", "but those outputs are not the vectors of layer activations\n", "(aka embeddings)\n", "that are produced by the network.\n", "They are instead the processed outputs,\n", "after a `softmax` and an `argmax`.\n", "\n", "We could obtain these outputs by processing the embeddings,\n", "much like in a recurrent architecture.\n", "In fact, that is one way that Transformers are run.\n", "It's what happens in the `.forward` method\n", "of the model we'll be training for character recognition:\n", "`ResnetTransformer`." ] }, { "cell_type": "markdown", "metadata": { "id": "L5_2WMmtDnJn" }, "source": [ "Let's look at that forward method\n", "and connect it to the diagram." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FR5pk4kEyCGg" }, "outputs": [], "source": [ "from text_recognizer.models import ResnetTransformer\n", "\n", "\n", "ResnetTransformer.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "-J5UFDoPzPbq" }, "source": [ "`.encode` happens first -- that's the left side of diagram.\n", "\n", "The encoder can in principle be anything\n", "that produces a sequence of fixed-length vectors,\n", "but here it's\n", "[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n", "\n", "Then we start iterating over the sequence\n", "in the `for` loop.\n", "\n", "Focus on the first few lines of code.\n", "We apply `.decode` (right side of diagram)\n", "to the outputs so far.\n", "\n", "Once we have a new `output`, we apply `.argmax`\n", "to turn the logits into a concrete prediction of\n", "a particular token.\n", "\n", "This is added as the last output token\n", "and then the loop happens again." 
] }, { "cell_type": "markdown", "metadata": { "id": "LTcy8-rV1dHr" }, "source": [ "Run this way, our model looks very much like a recurrent architecture:\n", "we call the model on its own outputs\n", "to generate the next value.\n", "These types of models are also referred to as\n", "[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n", "because we predict (as we do in _regression_)\n", "the next value based on our own (_auto_) output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But Transformers are designed to be _trained_ more scalably than RNNs,\n", "not necessarily to _run inference_ more scalably,\n", "and it's actually not the case that our model's `.forward` is called during training." ] }, { "cell_type": "markdown", "metadata": { "id": "eCxMSAWmEKBt" }, "source": [ "Let's look at what happens during training\n", "by checking the `training_step`\n", "of the `LightningModule`\n", "we use to train our Transformer models,\n", "the `TransformerLitModel`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0o7q8N7P2w4H" }, "outputs": [], "source": [ "from text_recognizer.lit_models import TransformerLitModel\n", "\n", "TransformerLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "1VgNNOjvzC4y" }, "source": [ "Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`." ] }, { "cell_type": "markdown", "metadata": { "id": "tz-6NGPR4dUr" }, "source": [ "Let's look at `.teacher_forward`,\n", "and in particular its type signature:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ILc2oWET4i2Z" }, "outputs": [], "source": [ "TransformerLitModel.teacher_forward??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`." 
] }, { "cell_type": "markdown", "metadata": { "id": "lf32lpgrDb__" }, "source": [ "This is known as \"teacher forcing\".\n", "The \"teacher\" signal is \"forcing\"\n", "the model to behave as though\n", "it got the answer right.\n", "\n", "[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n", "It's more effective here\n", "because the right teaching signal\n", "for our network is the target data,\n", "which we have access to during training,\n", "whereas in an RNN the best teaching signal\n", "would be the target embedding vector,\n", "which we do not know.\n", "\n", "During inference, when we don't have access to the ground truth,\n", "we revert to the autoregressive `.forward` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This \"trick\" allows Transformer architectures to readily scale\n", "up models to the parameter counts\n", "[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)." ] }, { "cell_type": "markdown", "metadata": { "id": "BAjqpJm9uUuU" }, "source": [ "## Is there more to Transformers than just a training trick?" ] }, { "cell_type": "markdown", "metadata": { "id": "kWCYXeHv7Qc9" }, "source": [ "[Very](https://arxiv.org/abs/2005.14165),\n", "[very](https://arxiv.org/abs/1909.08053),\n", "[very](https://arxiv.org/abs/2205.01068)\n", "large Transformer models have powered the most recent wave of exciting results in ML, like\n", "[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n", "\n", "They are also the first machine learning models to have come anywhere close to\n", "deserving the term _artificial intelligence_ --\n", "a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is surprising because the models and their training procedure are\n", "(relatively speaking)\n", "pretty _simple_,\n", "even if it doesn't feel that way on first pass." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic Transformer architecture is just a bunch of\n", "dense matrix multiplications and non-linearities --\n", "it's perhaps simpler than a convolutional architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And advances since the introduction of Transformers in 2017\n", "have not in the main been made by\n", "creating more sophisticated model architectures\n", "but by increasing the scale of the base architecture,\n", "or if anything making it simpler, as in\n", "[GPT-type models](https://arxiv.org/abs/2005.14165),\n", "which drop the encoder." ] }, { "cell_type": "markdown", "metadata": { "id": "V1HQS9ey8GMc" }, "source": [ "These models are also trained on very simple tasks:\n", "most LLMs are just trying to predict the next element in the sequence,\n", "given the previous elements --\n", "a task simple enough that Claude Shannon,\n", "father of information theory, was\n", "[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n", "\n", "These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n", "e.g. by scraping the web." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They are also trained in a simple fashion:\n", "first-order stochastic optimizers, like SGD or an\n", "[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n", "intended for the most basic of optimization problems,\n", "that scale more readily than the second-order optimizers\n", "that dominate other areas of optimization." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kz9HPDoy7OAl" }, "source": [ "This is\n", "[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n", "of work in ML:\n", "simple, even seemingly wasteful,\n", "architectures that scale well and are robust\n", "to implementation details\n", "eventually outstrip more clever but\n", "also more finicky approaches that are harder to scale.\n", "This lesson has led some to declare that\n", "[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n", "in machine learning, and perhaps even in artificial intelligence." ] }, { "cell_type": "markdown", "metadata": { "id": "SdN9o2Y771YZ" }, "source": [ "> That is not to say that because the algorithms are relatively simple,\n", " training a model at this scale is _easy_ --\n", " [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n", " [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n", " [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n", " But choosing the simplest algorithm at every step makes solving the scaling problem feasible." ] }, { "cell_type": "markdown", "metadata": { "id": "baVGf6gKFOvs" }, "source": [ "The importance of scale is the key lesson from the Transformer architecture,\n", "far more than any theoretical considerations\n", "or any of the implementation details.\n", "\n", "That said, these large Transformer models are capable of\n", "impressive behaviors and understanding how they achieve them\n", "is of intellectual interest.\n", "Furthermore, like any architecture,\n", "there are common failure modes,\n", "of the model and of the modelers who use them,\n", "that need to be taken into account." 
] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "Below, we'll cover two key intuitions about Transformers:\n", "Transformers are _residual_, like ResNets,\n", "and they compose _low rank_ sequence transformations.\n", "Together, this means they act somewhat like a computer,\n", "reading from and writing to a \"tape\" or memory\n", "with a sequence of simple instructions." ] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "We'll also cover a surprising implementation detail:\n", "despite being commonly used for sequence modeling,\n", "by default the architecture is _position insensitive_." ] }, { "cell_type": "markdown", "metadata": { "id": "uni0VTCr9lev" }, "source": [ "### Intuition #1: Transformers are highly residual." ] }, { "cell_type": "markdown", "metadata": { "id": "0MoBt-JLJz-d" }, "source": [ "> Our treatment of these intuitions summarizes the discussion in\n", "[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n", "from\n", "[Anthropic](https://www.anthropic.com/),\n", "an AI safety and research company.\n", "The figures below are from that blog post.\n", "It is the spiritual successor to the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "covered in\n", "[Lab 02b](https://lab02b-colab).\n", "If you want to truly understand Transformers,\n", "we highly recommend you check it out,\n", "including the\n", "[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
] }, { "cell_type": "markdown", "metadata": { "id": "UUbNVvM5Ferm" }, "source": [ "It's easy to see that ResNets are residual --\n", "it's in the name, after all.\n", "\n", "But Transformers are,\n", "in some sense,\n", "even more closely tied to residual computation\n", "than are ResNets:\n", "ResNets and related architectures include downsampling,\n", "so there is not a direct path from inputs to outputs.\n", "\n", "In Transformers, the exact same shape is maintained\n", "from the moment tokens are embedded,\n", "through dozens or hundreds of intermediate layers,\n", "and until they are \"unembedded\" into class logits.\n", "The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n", "\n", "The residual stream is easy to see with a change of perspective.\n", "Instead of the usual architecture diagram above,\n", "which emphasizes the layers acting on the tensors,\n", "consider this alternative view,\n", "which emphasizes the tensors as they pass through the layers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HRMlVguKKW6y" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-residual-view.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9K3N7ilVkB3" }, "source": [ "For definitions of variables and terms, see the\n", "[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)." ] }, { "cell_type": "markdown", "metadata": { "id": "arvciE-kKd_L" }, "source": [ "Note that this is a _decoder-only_ Transformer architecture --\n", "so it should be compared with the right-hand side of the original architecture diagram above."
] }, { "cell_type": "markdown", "metadata": { "id": "wvrRMd_RKp_G" }, "source": [ "Notice that outputs of the attention blocks \n", "and of the MLP layers are\n", "added to their inputs, as in a ResNet.\n", "These operations are represented as \"Add & Norm\" layers in the classical diagram;\n", "normalization is ignored here for simplicity." ] }, { "cell_type": "markdown", "metadata": { "id": "o8n_iT-FFAbK" }, "source": [ "This total commitment to residual operations\n", "means the size of the embeddings\n", "(referred to as the \"model dimension\" or the \"embedding dimension\",\n", "here and below `d_model`)\n", "stays the same throughout the entire network.\n", "\n", "That means, for example,\n", "that the output of each layer can be used as input to the \"unembedding\" layer\n", "that produces logits.\n", "We can read out the computations of intermediate layers\n", "just by passing them through the unembedding layer\n", "and examining the logit tensor.\n", "See\n", "[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n", "for detailed experiments and interactive notebooks.\n", "\n", "In short, we observe a sort of \"progressive refinement\"\n", "of the next-token prediction\n", "as the embeddings proceed, depthwise, through the network." ] }, { "cell_type": "markdown", "metadata": { "id": "Ovh_3YgY9z2h" }, "source": [ "### Intuition #2: Transformer heads learn low rank transformations."
] }, { "cell_type": "markdown", "metadata": { "id": "XpNmozlnOdPC" }, "source": [ "In the original paper and in\n", "most presentations of Transformers,\n", "the attention layer is written like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PA7me8gNP5LE" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In pseudo-typed PyTorch (based loosely on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n", "that looks like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Oeict_6wGJgD" }, "source": [ "```python\n", "def classic_attention(\n", "    Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n", "    K: torch.Tensor[\"d_sequence\", \"d_model\"],\n", "    V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", "    # softmax over the last axis, so each query's attention weights sum to 1\n", "    return torch.softmax(Q @ K.T, dim=-1) @ V\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "8pewU90DSuOR" }, "source": [ "This is almost exactly\n", "how it is written\n", "in PyTorch,\n", "apart from implementation details\n", "(look for `bmm` for the matrix multiplications and a `softmax` call):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrgTpKFvOhwc" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "F._scaled_dot_product_attention??"
] }, { "cell_type": "markdown", "metadata": { "id": "ebDXZ0tlSe7g" }, "source": [ "But the best way to write an operation so that a computer can execute it quickly\n", "is not necessarily the best way to write it so that a human can understand it --\n", "otherwise we'd all be coding in assembly.\n", "\n", "And this is a strange way to write it --\n", "you'll notice that what we normally think of\n", "as the \"inputs\" to the layer are not shown.\n", "\n", "We can instead write out the attention layer\n", "as a function of the inputs $x$.\n", "We write it for a single \"attention head\".\n", "Each attention layer includes a number of heads\n", "that read and write from the residual stream\n", "simultaneously and independently.\n", "We also add the output layer weights $W_O$\n", "and we get:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LuFNR67tQpsf" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(\\underbrace{x W_Q^T}_Q \\underbrace{W_K x^T}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SVnBjjfOLwxP" }, "source": [ "or, in pseudo-typed PyTorch:" ] }, { "cell_type": "markdown", "metadata": { "id": "LmpOm-HfGaNz" }, "source": [ "```python\n", "def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", "    query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n", "    key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n", "    key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n", "    # maps queries of residual stream to keys from residual stream, independent of position\n", "\n", "    value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n", "    output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n", "    value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n", "    # transformation applied to each token, regardless of position\n", "\n", "    attention_logits = x @ key_query_circuit @ x.T\n", "    attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n", "    # maps positions to positions, often very sparse\n", "\n", "    value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n", "\n", "    return attention_map @ value_output  # transformed tokens filtered by attention map\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "dC0eqxZ6UAGT" }, "source": [ "Consider the `key_query_circuit`\n", "and `value_output_circuit`\n", "matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n", "\n", "The key/query dimension, `d_head`,\n", "is small relative to the model's dimension, `d_model`,\n", "so $W_{QK}$ and $W_{OV}$ are very low rank,\n", "[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n", "that they factorize into two matrices,\n", "one with a smaller number of rows\n", "and another with a smaller number of columns.\n", "That number is called the _rank_.\n", "\n", "When computing, these matrices are better represented via their components,\n", "rather than computed directly,\n", "which leads to the normal implementation of attention.\n", "\n", "In a large language model,\n", "the ratio of residual stream dimension, `d_model`, to\n", "the dimension of a single head, `d_head`, is huge, often 100:1.\n", "That means each query, key, and value computed at a position\n", "is a fairly simple, low-dimensional feature of the residual stream at that position.\n", "\n", "For visual intuition,\n", "we compare a matrix whose rank is one hundredth of full rank\n", "with a full rank matrix of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LUbojJMiW2C" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "\n", "low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n", "full_rank = torch.randn(100, 100)\n", "plt.figure();
plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n", "plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");" ] }, { "cell_type": "markdown", "metadata": { "id": "lqBst92-OVka" }, "source": [ "The pattern in the first matrix is very simple,\n", "relative to the pattern in the second matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "SkCGrs9EiVh4" }, "source": [ "Another feature of low rank transformations is\n", "that they have a large nullspace or kernel --\n", "these are directions we can move the input without changing the output.\n", "\n", "That means that many changes to the residual stream won't affect the behavior of this head at all." ] }, { "cell_type": "markdown", "metadata": { "id": "UVz2dQgzhD4p" }, "source": [ "### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)." ] }, { "cell_type": "markdown", "metadata": { "id": "hVlzwR03m8mC" }, "source": [ "The combination of residuality\n", "(changes are added to the current input)\n", "and low rank\n", "(only a small subspace is changed by each head)\n", "drastically changes the intuition about Transformers." ] }, { "cell_type": "markdown", "metadata": { "id": "qqjZI2jKe6HH" }, "source": [ "Rather than being an \"embedding of a token in its context\",\n", "the residual stream becomes something more like a memory or a scratchpad:\n", "one layer reads a small bit of information from the stream\n", "and writes a small bit of information back to it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YIBkxlqepjc" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-layer-residual.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtsKhkLfk00l" }, "source": [ "The residual stream works like a memory because it is roomy enough\n", "that these actions need not interfere:\n", "the subspaces targeted by reads and writes are small relative to the ambient space,\n", "so they need not overlap.\n", "\n", "Additionally, the dimension of each head is still in the 100s in large models,\n", "and\n", "[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n", "in them, so the number of effective degrees of freedom is\n", "actually larger than the dimension.\n", "This phenomenon allows high-dimensional tensors to serve as\n", "[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n", "There are\n", "[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n", "\n", "Together, this means an early layer can write information to the stream\n", "that can be used by later layers -- by many of them at once, possibly much later.\n", "Later layers can learn to edit this information,\n", "e.g. deleting it,\n", "if doing so reduces the loss,\n", "but by default the information is preserved." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EragIygzJg86" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-stream-read-write.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "oKIaUZjwkpW7" }, "source": [ "Lastly, the softmax in the attention has a sparsifying effect,\n", "and so many attention heads are reading from \n", "just one token and writing to just one other token."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dN6VcJqIMKnB" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-token-to-token.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Repeatedly reading information from an external memory\n", "and using it to decide which operation to perform\n", "and where to write the results\n", "is at the core of the\n", "[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n", "For a concrete example, the\n", "[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n", "includes a dissection of a form of \"pointer arithmetic\"\n", "that appears in some models." ] }, { "cell_type": "markdown", "metadata": { "id": "0kLFh7Mvnolr" }, "source": [ "This point of view seems\n", "very promising for explaining numerous\n", "otherwise perhaps counterintuitive features of Transformer models.\n", "\n", "- This framework predicts that Transformers will readily copy-and-paste information,\n", "which might explain phenomena like\n", "[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n", "\n", "- It also readily explains\n", "[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n", "an important component of why Transformers perform well on medium-length texts\n", "and in few-shot learning.\n", "\n", "- Transformers also perform better on reasoning tasks when the text\n", "[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n", "is added to their input prompt.\n", "This is partly because that prompt is associated,\n", "in the dataset, with clearer reasoning,\n", "and since the models are trained to predict which tokens tend to appear\n", "after an input, they tend to produce better reasoning with that prompt --\n", "an explanation purely in terms of sequence modeling.\n",
"But it also gives the Transformer license to generate a large number of tokens\n", "that act to store intermediate information,\n", "making for a richer residual stream\n", "for reading and writing." ] }, { "cell_type": "markdown", "metadata": { "id": "RyLRzgG-93yB" }, "source": [ "### Implementation detail: Transformers are position-insensitive by default." ] }, { "cell_type": "markdown", "metadata": { "id": "oR6PnrlA_hJ2" }, "source": [ "In the attention calculation\n", "each token can query each other token,\n", "with no regard for order.\n", "Furthermore, the construction of queries, keys, and values\n", "is based on the content of the embedding vector,\n", "which does not automatically include its position.\n", "\"dog bites man\" and \"man bites dog\" are identical, as in\n", "[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n", "\n", "For most sequences,\n", "this is unacceptable:\n", "absolute and relative position matter\n", "and we cannot use the future to predict the past.\n", "\n", "We need to add two pieces to get a Transformer architecture that's usable for next-token prediction." ] }, { "cell_type": "markdown", "metadata": { "id": "EWHxGJz2-6ZK" }, "source": [ "First, the simpler piece:\n", "\"causal\" attention,\n", "so-named because it ensures that values earlier in the sequence\n", "are not influenced by later values, which would\n", "[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)." 
] }, { "cell_type": "markdown", "metadata": { "id": "0c42xi6URYB4" }, "source": [ "The most common solution is straightforward:\n", "we calculate attention between all tokens,\n", "then throw out non-causal values by \"masking\" them\n", "(this is before applying the softmax,\n", "so masking means adding $-\\infty$).\n", "\n", "This feels wasteful --\n", "why are we calculating values we don't need?\n", "Trying to be smarter would be harder,\n", "and might rely on operations that aren't as optimized as\n", "matrix multiplication and addition.\n", "Furthermore, it's \"only\" twice as many operations,\n", "so it doesn't even show up in $O$-notation.\n", "\n", "A sample attention mask generated by our code base is shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NXaWe6pT-9jV" }, "outputs": [], "source": [ "from text_recognizer.models import transformer_util\n", "\n", "\n", "attention_mask = transformer_util.generate_square_subsequent_mask(100)\n", "\n", "ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n", "plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n", "print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This solves our causality problem,\n", "but we still don't have positional information." 
] }, { "cell_type": "markdown", "metadata": { "id": "ZamUE4WIoGS2" }, "source": [ "The standard technique\n", "is to add alternating sines and cosines\n", "of increasing frequency to the embeddings\n", "(there are\n", "[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n", "most notably\n", "[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n", "Each position in the sequence is then uniquely identifiable\n", "from the pattern of these values.\n", "\n", "> Furthermore, for the same reason that\n", " [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n", " translations, e.g. relative positions, are fairly easy to express as linear transformations\n", " of sines and cosines." ] }, { "cell_type": "markdown", "metadata": { "id": "IDG2uOsaELU0" }, "source": [ "We superimpose this positional information on our embeddings.\n", "Note that because the model is residual,\n", "this position information will be by default preserved\n", "as it passes through the network,\n", "so it doesn't need to be repeatedly added."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what this positional encoding looks like in our codebase:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Zk62Q-a-1Ax" }, "outputs": [], "source": [ "PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n", "\n", "pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n", "\n", "ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n", "print(pe[:4, :8])" ] }, { "cell_type": "markdown", "metadata": { "id": "ep2ClIWvqDms" }, "source": [ "When we add the positional information to our embeddings,\n", "both the embedding information and the positional information\n", "are approximately preserved,\n", "as can be visually assessed below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PJuFjoCzC0Y4" }, "outputs": [], "source": [ "fake_embeddings = torch.randn_like(pe) * 0.5\n", "\n", "ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n", "\n", "fake_embeddings_with_pe = fake_embeddings + pe\n", "\n", "plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIzBxDkEmH8" }, "source": [ "A [similar technique](https://arxiv.org/abs/2103.06450)\n", "is also used to incorporate positional information into the image embeddings,\n", "which are flattened before being fed to the decoder." 
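] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, the alternating sines and cosines used for the text embeddings\n", "can be sketched in a few lines.\n", "This is the textbook formulation from \"Attention Is All You Need\",\n", "written under the assumption that `d_model` is even.\n", "The function name `sinusoidal_encoding_sketch` is hypothetical,\n", "and the details may differ from `transformer_util.PositionalEncoding`,\n", "so treat it as an illustration rather than our implementation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "def sinusoidal_encoding_sketch(max_len, d_model):\n", "    \"\"\"Textbook sinusoidal positional encoding; assumes d_model is even.\"\"\"\n", "    position = torch.arange(max_len).unsqueeze(1)  # shape (max_len, 1)\n", "    # one frequency per sine/cosine pair of embedding dimensions, shape (d_model / 2,)\n", "    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n", "\n", "    encoding = torch.zeros(max_len, d_model)\n", "    encoding[:, 0::2] = torch.sin(position * div_term)  # even embedding dimensions\n", "    encoding[:, 1::2] = torch.cos(position * div_term)  # odd embedding dimensions\n", "    return encoding\n", "\n", "\n", "sketch_pe = sinusoidal_encoding_sketch(max_len=200, d_model=50)\n", "print(sketch_pe.shape)  # (sequence index, embedding dimension)" 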
] }, { "cell_type": "markdown", "metadata": { "id": "HC1N85wl8dvn" }, "source": [ "### Learn more about Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "lJwYxkjTk15t" }, "source": [ "We're only able to give a flavor and an intuition for Transformers here.\n", "\n", "To improve your grasp on the nuts and bolts, check out the\n", "[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n", "which is surprisingly approachable,\n", "as far as ML research papers go.\n", "The\n", "[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n", "adds code and commentary to the original paper,\n", "which makes it even more digestible.\n", "For something even friendlier, check out the\n", "[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n", "by Jay Alammar, which has an accompanying\n", "[video](https://youtu.be/-QH8fRhqFHM).\n", "\n", "Anthropic's work on\n", "[Transformer Circuits](https://transformer-circuits.pub/),\n", "summarized above, has some of the best material\n", "for building theoretical understanding\n", "and is still being updated with extensions and applications of the framework.\n", "The\n", "[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n", "are a great aid for checking and building your understanding.\n", "\n", "But they are fairly math-heavy.\n", "If you have more of a software engineering background, see\n", "Transformer Circuits co-author Nelson Elhage's blog post\n", "[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n", "\n", "For a gentler introduction to the intuition for Transformers,\n", "check out Brandon Rohrer's\n", "[Transformers From Scratch](https://e2eml.school/transformers.html)\n", "tutorial." 
] }, { "cell_type": "markdown", "metadata": { "id": "qg7zntJES-aT" }, "source": [ "An aside:\n", "the matrix multiplications inside attention dominate\n", "the big-$O$ runtime of Transformers.\n", "So trying to make the attention mechanism more efficient, e.g. linear time,\n", "has generated a lot of research\n", "(review paper\n", "[here](https://arxiv.org/abs/2009.06732)).\n", "Despite drawing a lot of attention, so to speak,\n", "at the time of writing in mid-2022, these methods\n", "[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n", "so it isn't likely to be worth the effort to spend time learning about them\n", "unless you are a Transformer specialist." ] }, { "cell_type": "markdown", "metadata": { "id": "vCjXysEJ8g9_" }, "source": [ "# Using Transformers to read paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "KsfKWnOvqjva" }, "source": [ "Our simple convolutional model for text recognition from\n", "[Lab 02b](https://fsdl.me/lab02b-colab)\n", "could only handle cleanly-separated characters.\n", "\n", "It worked by sliding a LeNet-style CNN\n", "over the image,\n", "predicting a character for each step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njLdzBqy-I90" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist_lines = text_recognizer.data.EMNISTLines()\n", "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "\n", "# for sliding, see the for loop over range(S)\n", "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "K0N6yDBQq8ns" }, "source": [ "But unfortunately for us, handwritten text\n", "doesn't come in neatly-separated characters\n", "of equal size, so we trained our model on synthetic data\n", "designed to work with that model." 
] }, { "cell_type": "markdown", "metadata": { "id": "hiqUVbj0sxLr" }, "source": [ "Now that we have a better model,\n", "we can work with better data:\n", "paragraphs from the\n", "[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)." ] }, { "cell_type": "markdown", "metadata": { "id": "oizsOAcKs-dD" }, "source": [ "The cell below uses our `LightningDataModule`\n", "to download and preprocess this data,\n", "writing results to disk.\n", "We can then spin up `DataLoader`s to give us batches.\n", "\n", "It can take several minutes to run the first time\n", "on commodity machines,\n", "with most time spent extracting the data.\n", "On subsequent runs,\n", "the time-consuming operations will not be repeated." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uL9LHbjdsUbm" }, "outputs": [], "source": [ "iam_paragraphs = text_recognizer.data.IAMParagraphs()\n", "\n", "iam_paragraphs.prepare_data()\n", "iam_paragraphs.setup()\n", "xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n", "\n", "iam_paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "nBkFN9bbTm_S" }, "source": [ "Now that we've got a batch,\n", "let's take a look at some samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hqaps8yxtBhU" }, "outputs": [], "source": [ "import random\n", "\n", "import numpy as np\n", "import wandb\n", "\n", "\n", "def show(y):\n", " y = y.detach().cpu() # bring back from accelerator if it's being used\n", " return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"<P>\", \"\")\n", "\n", "idx = random.randint(0, len(xs) - 1) # randint includes both endpoints\n", "\n", "print(show(ys[idx]))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "4dT3UCNzTsoc" }, "source": [ "The `ResnetTransformer` model can run on this data\n", "if passed the `.config`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WXL-vIGRr86D" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())" ] }, { "cell_type": "markdown", "metadata": { "id": "MMxa-oWyT01E" }, "source": [ "Our models are now big enough\n", "that we want to make use of GPU acceleration\n", "as much as we can,\n", "even when working on single inputs,\n", "so let's cast to the GPU if we have one." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YyUM8LgvW0w" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "rnt.to(device); xs = xs.to(device); ys = ys.to(device);" ] }, { "cell_type": "markdown", "metadata": { "id": "Y-E3UdD4zUJi" }, "source": [ "First, let's just pass it through the ResNet encoder." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-LUUtlvaxrvg" }, "outputs": [], "source": [ "resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n", " # resnet is designed for RGB images, so we replicate the input across channels 3 times" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eimgJ5dnywjg" }, "outputs": [], "source": [ "resnet_idx = random.randint(0, len(resnet_embedding) - 1) # re-execute to view a different channel\n", "plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n", "plt.axis(\"off\"); plt.colorbar(fraction=0.05);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These embeddings, though generated by random, untrained weights,\n", "are not entirely useless.\n", "\n", "Before neural networks could be effectively\n", "trained end to end,\n", "they were often used with frozen random weights\n", "everywhere except the final layer\n", "(see e.g.\n", "[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n", "[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n", "these methods were still competitive, and\n", "[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n", "provide a\n", "[theoretical basis](https://arxiv.org/abs/2011.14522)\n", "for understanding their performance." ] }, { "cell_type": "markdown", "metadata": { "id": "ye6pW0ETzw2A" }, "source": [ "The final result, though, is repetitive gibberish --\n", "at the bare minimum, we need to train the unembedding/readout layer\n", "in order to get reasonable text." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our architecture includes randomization with dropout,\n", "so repeated runs of the cell below will generate different outcomes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu3Pa7gLsFMo" }, "outputs": [], "source": [ "preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gvCXUbskv6XM" }, "outputs": [], "source": [ "print(show(preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without teacher forcing, runtime is also variable from iteration to iteration --\n", "the model stops when it generates an \"end sequence\" or padding token,\n", "which is not deterministic thanks to the dropout layers.\n", "For similar reasons, runtime is variable across inputs.\n", "\n", "The variable runtime of autoregressive generation\n", "is also not great for scaling.\n", "In a distributed setting, as required for large scale,\n", "forward passes need to be synced across devices,\n", "and if one device is generating a batch of much longer sequences,\n", "it will cause all the others to idle while they wait on it to finish." ] }, { "cell_type": "markdown", "metadata": { "id": "t76MSVRXV0V7" }, "source": [ "Let's turn our model into a `TransformerLitModel`\n", "so we can run with teacher forcing.\n", "\n", "> You may be wondering:\n", " why isn't teacher forcing part of the PyTorch module?\n", " In general, the `LightningModule`\n", " should encapsulate things that are needed in training, validation, and testing\n", " but not during inference.\n", " The teacher forcing trick fits this paradigm,\n", " even though it's so critical to what makes Transformers powerful. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8qrHRKHowdDi" }, "outputs": [], "source": [ "import text_recognizer.lit_models\n", "\n", "lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlNaFqR50Oid" }, "source": [ "Now we can use `.teacher_forward` if we also provide the target `ys`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lpZdqXS5wn0F" }, "outputs": [], "source": [ "forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Zx9SmsN0QLT" }, "source": [ "This may not run faster than the `rnt.forward`,\n", "since generations are always the maximum possible length,\n", "but runtimes and output lengths are deterministic and constant." ] }, { "cell_type": "markdown", "metadata": { "id": "tu-XNYpi0Qvi" }, "source": [ "Forcing doesn't necessarily make our predictions better.\n", "They remain highly repetitive gibberish." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JcEgify9w0sv" }, "outputs": [], "source": [ "forcing_preds = torch.argmax(forcing_outs, dim=0)\n", "\n", "print(show(forcing_preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xn6GGNzc9a3o" }, "source": [ "## Training the `ResNetTransformer`" ] }, { "cell_type": "markdown", "metadata": { "id": "uvZYsuSyWUXe" }, "source": [ "We're finally ready to train this model on full paragraphs of handwritten text!" 
] }, { "cell_type": "markdown", "metadata": { "id": "3cJwC7b720Sd" }, "source": [ "This is a more serious model --\n", "it's the one we use in the\n", "[deployed TextRecognizer application](http://fsdl.me/app).\n", "It's much larger than the models we've seen thus far,\n", "so it can easily outstrip available compute resources,\n", "in particular GPU memory.\n", "\n", "To help, we use\n", "[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n", "which shrinks the size of most of our floats by half,\n", "reducing memory consumption and often speeding up computation.\n", "\n", "If your GPU has less than 8GB of available RAM,\n", "you'll see a \"CUDA out of memory\" `RuntimeError`,\n", "which is something of a\n", "[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n", "In this case, you can resolve it by reducing the `--batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w1mXlhfy04Nm" }, "outputs": [], "source": [ "import torch\n", "\n", "gpus = int(torch.cuda.is_available())\n", "\n", "if gpus:\n", " !nvidia-smi\n", "else:\n", " print(\"watch out! working with this model on a typical CPU is not feasible\")" ] }, { "cell_type": "markdown", "metadata": { "id": "os1vW1rPZ1dy" }, "source": [ "Even with an okay GPU, like a\n", "[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n", "a single epoch of training can take over 10 minutes to run.\n", "We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n", "but you can remove those flags to see what full training looks like." 
] }, { "cell_type": "markdown", "metadata": { "id": "vnF6dWFn4JlZ" }, "source": [ "It can take a long time (overnight)\n", "to train this model to decent performance on a single GPU,\n", "so we'll focus on other pieces for the exercises.\n", "\n", "> At the time of writing in mid-2022, the cheapest readily available option\n", "for training this model to decent performance on this dataset with this codebase\n", "comes out around $10, using\n", "[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n", "See, for example,\n", "[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n", "and associated experiment.\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HufjdUZN0t4l", "scrolled": false }, "outputs": [], "source": [ "%%time\n", "# above %%magic times the cell, useful as a poor man's profiler\n", "\n", "%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n", " --gpus={gpus} --batch_size 16 --precision 16 \\\n", " --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2" ] }, { "cell_type": "markdown", "metadata": { "id": "L6fQ93ju3Iku" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "udb1Ekjx3L63" }, "source": [ "### 🌟 Try out gradient accumulation and other \"training tricks\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "kpqViB4p3Wfb" }, "source": [ "Larger batches are helpful not only for increasing parallelization\n", "and amortizing fixed costs\n", "but also for getting more reliable gradients.\n", "Larger batches give gradients with less noise\n", "and to a point, less gradient noise means faster convergence.\n", "\n", "But larger batches result in larger tensors,\n", "which take up more GPU memory,\n", "a resource that is tightly constrained\n", "and device-dependent.\n", "\n", "Does that mean we are limited in the quality of our gradients\n", "due to our machine size?\n", "\n", "Not entirely:\n", "look up the `--accumulate_grad_batches`\n", "argument to the `pl.Trainer`.\n", "You should be able to understand why\n", "it makes it possible to compute the same gradients\n", "you would find for a batch of size `k * N`\n", "on a machine that can only run batches up to size `N`.\n", "\n", "Accumulating gradients across batches is among the\n", "[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n", "Try some of them out!\n", "Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment." 
] }, { "cell_type": "markdown", "metadata": { "id": "b2vtkmX830y3" }, "source": [ "### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n", "\n", "While training this model to actually fit the whole dataset is infeasible\n", "as a short exercise on commodity hardware,\n", "it's practical to train this model to memorize a batch of 16 examples.\n", "\n", "Passing `--overfit_batches 1` flag limits the number of training batches to 1\n", "and turns off\n", "[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n", "so that in each epoch, the model just sees the same single batch of data over and over again.\n", "\n", "At first, try training the model to a loss of `2.5` --\n", "it should be doable in 100 epochs or less,\n", "which is just a few minutes on a commodity GPU.\n", "\n", "Once you've got that working,\n", "crank up the number of epochs by a factor of 10\n", "and confirm that the loss continues to go down.\n", "\n", "Some tips:\n", "\n", "- Use `--limit_test_batches 0` to turn off testing.\n", "We don't need it because we don't care about generalization\n", "and it's relatively slow because it runs the model autoregressively.\n", "\n", "- Use `--help` and look through the model class args\n", "to find the arguments used to reduce model size.\n", "\n", "- By default, there's lots of regularization to prevent overfitting.\n", "Look through the args for the model class and data class\n", "for regularization knobs to turn off or down." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab03_transformers.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab05/notebooks/lab04_experiments.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 04: Experiment Management" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How experiment management brings observability to ML model development\n", "- Which features of experiment management we use in developing the Text Recognizer\n", "- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 4\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This lab contains a large number of embedded iframes\n", "that benefit from having a wide window.\n", "The cell below makes the notebook as wide as your browser window\n", "if `full_width` is set to `True`.\n", "Full width is the default behavior in Colab,\n", "so this cell is intended to improve the viewing experience in other Jupyter environments." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Why experiment management?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand why we need experiment management for ML development,\n", "let's start by running an experiment.\n", "\n", "We'll train a new model on a new dataset,\n", "using the training script `training/run_experiment.py`\n", "introduced in [Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a CNN encoder and Transformer decoder, as in\n", "[Lab 03](https://fsdl.me/lab03-colab),\n", "but with some changes so we can iterate faster.\n", "We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n", "[Lab 02b](https://fsdl.me/lab02b-colab),\n", "and we'll use a smaller CNN (`--model_class LineCNNTransformer`)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n", "from text_recognizer.data import IAMLines # processed version split into individual lines\n", "from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n", "\n", "\n", "print(IAM.__doc__)\n", "\n", "# uncomment a line below for details on either class\n", "# IAMLines?? \n", "# LineCNNTransformer??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below will train a model on 10% of the data for two epochs.\n", "\n", "It takes up to a few minutes to run on commodity hardware,\n", "including data download and preprocessing.\n", "As it's running, continue reading below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%%time\n", "import torch\n", "\n", "\n", "gpus = int(torch.cuda.is_available()) \n", "\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the model trains, we're calculating lots of metrics --\n", "loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n", "and reporting them to the terminal.\n", "\n", "This is achieved by the built-in `.log` method\n", "([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n", "of the `LightningModule`,\n", "and it is a very straightforward way to get basic information about your experiment as it's running\n", "without leaving the context where you're running it." 
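] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of what such a `.log` call looks like,\n", "here is a minimal sketch of a `LightningModule`.\n", "The class `LoggingSketch`, its layer sizes, and the metric name `train/loss`\n", "are hypothetical and chosen only for this example;\n", "this is not the Text Recognizer's actual module.\n", "The cell only defines the class and does not start any training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "import torch\n", "\n", "\n", "class LoggingSketch(pl.LightningModule):\n", "    \"\"\"Hypothetical minimal module showing where .log is called.\"\"\"\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.model = torch.nn.Linear(8, 2)  # toy classifier: 8 features, 2 classes\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        loss = torch.nn.functional.cross_entropy(self.model(x), y)\n", "        # report the value to the progress bar and to whichever logger the Trainer uses\n", "        self.log(\"train/loss\", loss, prog_bar=True)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.Adam(self.parameters(), lr=1e-3)" 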
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Learning to read\n", "[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n", "is something of a rite of passage for MLEs, but\n", "let's consider what we can't see here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We're missing all metric values except the most recent --\n", "we can see them as they stream in, but they're constantly overwritten.\n", "We also can't associate them with timestamps, steps, or epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We also don't see any system metrics.\n", "We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n", "without launching a separate process.\n", "And even if we do, those values will also not be saved and timestamped,\n", "so we can't correlate them with other things during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- As we continue to run experiments, changing code and opening new terminals,\n", "even the information we have or could figure out now will disappear.\n", "Say you spot a weird error message during training,\n", "but your session ends and the stdout is gone,\n", "so you don't know exactly what it was.\n", "Can you recreate the error?\n", "Which git branch and commit were you on?\n", "Did you have any uncommitted changes? Which arguments did you pass?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Also, model checkpoints containing the parameter values have been saved to disk.\n", "Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n", "As we run more and more experiments,\n", "we'll want to slice and dice them to see if,\n", "say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to save and log all of this information, and more, in order to make our model training\n", "[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n", "in short, so that we can understand, make decisions about, and debug our model training\n", "by looking at logs and source code, without having to recreate it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we had to write the logging code we need to save this information ourselves, that'd put us in for a world of hurt:\n", "1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n", "2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n", "3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n", "4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Local Experiment Tracking with TensorBoard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n", "and it makes logging very easy." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n", "\n", "It can also use a logger to push information elsewhere.\n", "\n", "By default, we use\n", "[TensorBoard](https://www.tensorflow.org/tensorboard)\n", "via the Lightning `TensorBoardLogger`,\n", "which has been saving results to the local disk.\n", "\n", "Let's find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest experiment's directory\n", "# by hand, you can just copy and paste it from the terminal\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n", "filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n", "latest_log" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!ls -lh {latest_log}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, we need to launch a TensorBoard server --\n", "much like we need to launch a Jupyter server to use Jupyter notebooks.\n", "\n", "The cells below load an extension that lets you use TensorBoard inside of a notebook\n", "the same way you'd use it from the command line, and then launch it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext tensorboard" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n", "\n", "port = 11717 # pick an open port on your machine\n", "host = \"0.0.0.0\" # allow connections from the internet\n", " # watch out! make sure you turn TensorBoard off\n", "\n", "%tensorboard --logdir {latest_log} --port {port} --host {host}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see some charts of metrics over time along with some charting controls.\n", "\n", "You can click around in this interface and explore it if you'd like,\n", "but in the next section, we'll see that there are better tools for experiment management." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've run many experiments on this machine,\n", "you can see all of their results by pointing TensorBoard\n", "at the whole `lightning_logs` directory,\n", "rather than just one experiment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For large numbers of experiments, the management experience is not great --\n", "it's for example hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n", "\n", "It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n", "which are important workflows as applications mature and teams grow." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n", "\n", "If you run into any issues with the above cells failing to launch,\n", "especially across iterations of this lab, run this cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorboard.manager\n", "\n", "# get the process IDs for all tensorboard instances\n", "pids = [tb.pid for tb in tensorboard.manager.get_all()]\n", "\n", "done_with_tensorboard = False\n", "\n", "if done_with_tensorboard:\n", "    # kill processes\n", "    for pid in pids:\n", "        !kill {pid} 2> /dev/null\n", "\n", "    # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n", "    !rm -rf {tensorboard.manager._get_info_dir()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment Management with Weights & Biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How do we manage experiments when we hit the limits of local TensorBoard?"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is powerful and flexible and very scalable,\n", "but running it requires engineering effort and babysitting --\n", "you're running a database, writing data to it,\n", "and layering a web application over it.\n", "\n", "This is a fairly common workflow for web developers,\n", "but not so much for ML engineers.\n", "\n", "You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n", "and it's as simple as running the command `tensorboard dev upload`\n", "pointed at your logging directory.\n", "\n", "But there are strict limits to this free service:\n", "1GB of tensor data and 1GB of binary data.\n", "A single Text Recognizer model checkpoint is ~100MB,\n", "and that's not particularly large for a useful model.\n", "\n", "Furthermore, all data is public,\n", "so if you upload the inputs and outputs of your model,\n", "anyone who finds the link can see them.\n", "\n", "Overall, tensorboard.dev works very well for certain academic and open projects\n", "but not for industrial ML." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get around those permission and storage limits,\n", "you could use [git LFS](https://git-lfs.github.com/)\n", "to track the binary data and tensor data,\n", "which are more likely to be sensitive than metrics.\n", "\n", "The Hugging Face ecosystem uses TensorBoard and git LFS.\n", "\n", "It includes the Hugging Face Hub, a git server much like GitHub,\n", "but designed first and foremost for collaboration on models and datasets,\n", "rather than collaboration on code.\n", "For example, the Hugging Face Hub\n", "[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n", "and officially has\n", "[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n", "avoiding the\n", "[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n", "that makes using git LFS with GitHub expensive.\n", "\n", "However, we prefer to avoid mixing software version control and experiment management.\n", "\n", "First, using the Hub requires maintaining an additional git remote,\n", "which is a hard ask for many engineering teams.\n", "\n", "Second, git-style versioning is an awkward fit for logging --\n", "is it really sensible to create a new commit for each logging event while you're watching live?\n", "\n", "Instead, we prefer to use systems that solve experiment management with _databases_."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n", "The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n", "tool is [MLflow](https://github.com/mlflow/mlflow/)\n", "and there are a number of\n", "[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n", "\n", "These tools generally avoid any need to worry about hosting\n", "(unless data governance rules require a self-hosted version).\n", "\n", "For a sampling of publicly-posted opinions on experiment management tools,\n", "see these discussions from Reddit:\n", "\n", "- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n", "- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n", "\n", "Among these tools, the FSDL recommendation is\n", "[Weights & Biases](https://wandb.ai),\n", "which we believe offers\n", "- the best user experience, both in the Python SDKs and in the graphical interface\n", "- the best integrations with other tools,\n", "including\n", "[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n", "[Keras](https://docs.wandb.ai/guides/integrations/keras),\n", "[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n", "and even\n", "[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n", "and\n", "- the best tools for collaboration.\n", "\n", "Below, we'll take care to point out which logging and management features\n", "are available via generic interfaces in Lightning and which are W&B-specific." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "\n", "print(wandb.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding W&B to our experiment-running code is extremely easy,\n", "relative to the features we get, which is\n", "one of the main selling points of W&B.\n", "\n", "We get most of our new experiment management features just by changing a single variable, `logger`, from\n", "`TensorBoardLogger` to `WandbLogger`\n", "and adding two lines of code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see what each of these lines does for us below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this logger is built into and maintained by PyTorch Lightning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning.loggers import WandbLogger\n", "\n", "\n", "WandbLogger??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to complete the rest of this notebook,\n", "you'll need a Weights & Biases account.\n", "\n", "As with GitHub, the free tier for personal, academic, and open source work\n", "is very generous.\n", "\n", "The Text Recognizer project will fit comfortably within the free tier.\n", "\n", "Run the cell below and follow the prompts to log in or create an account, or go\n", "[here](https://wandb.ai/signup)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb login" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the cell below to launch an experiment tracked with Weights & Biases.\n", "\n", "The experiment can take between 3 and 10 minutes to run.\n", "In that time, continue reading below."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n", " --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1\n", " \n", "last_expt = wandb.run\n", "\n", "wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see some new things in our output.\n", "\n", "For example, there's a note from `wandb` that the data is saved locally\n", "and also synced to their servers.\n", "\n", "There's a link to a webpage for viewing the logged data and a name for our experiment --\n", "something like `dandy-sunset-1`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local logging and cloud syncing happens with minimal impact on performance,\n", "because `wandb` launches a separate process to listen for events and upload them.\n", "\n", "That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Runs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, head to the link in the notebook output\n", "that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n", "\n", "There's no need to wait for training to finish.\n", "\n", "The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(last_expt.url)\n", "IFrame(last_expt.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have landed on the run page\n", "([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n", "which collects up all of the information for a single experiment into a collection of tabs.\n", "\n", "We'll work through these tabs from top to bottom.\n", "\n", "Each header is also a link to the documentation for a tab." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n", "This tab has an icon that looks like `(i)` or 🛈.\n", "\n", "The top section of this tab has high-level information about our run:\n", "- Timing information, like start time and duration\n", "- System hardware, hostname, and basic environment info\n", "- Git repository link and state\n", "\n", "This information is collected and logged automatically.\n", "\n", "The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n", "and summary metrics.\n", "\n", "Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n", "\n", "This tab has a line plot icon, something like 📈.\n", "\n", "It's also the default page you land on when looking at a W&B run.\n", "\n", "Charts are generated for everything we `.log` from PyTorch Lightning. 
The charts here are interactive and editable, and changes persist.\n", "\n", "Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n", "\n", "We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n", "This tab has a computer chip icon.\n", "\n", "It contains\n", "- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n", "- CPU metrics: memory usage, utilization, thread counts\n", "- Disk and network I/O levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n", "This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n", "\n", "The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n", "This tab has an icon that looks like a stylized command prompt, `>_`.\n", "\n", "It contains information that was printed to stdout.\n", "\n", "This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n", "\n", "Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelSummary\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n", "\n", "For more on Lightning `Callback`s, see\n", "[Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n", "This tab has a stylized document icon, something like 📄.\n", "\n", "You can use this tab to view any files saved with the `wandb.save`.\n", "\n", "For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n", "which we'll discuss shortly.\n", "\n", "But a few pieces of information automatically collected by W&B end up in this tab.\n", "\n", "Some highlights:\n", " - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n", " - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n", "This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n", "\n", "This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n", "\n", "We store two kinds of binary files\n", " - `run_table`s of model inputs and outputs\n", " - `model` checkpoints\n", "\n", "We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive Tables of Logged Media" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returning to the Charts tab,\n", "notice that we have model inputs and outputs logged in structured tables\n", "under the train, validation, and test sections.\n", "\n", "These tables are interactive as well\n", "([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n", "They support basic exploratory data analysis and are compatible with W&B's collaboration features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to charts in our run page, these tables also have their own pages inside the W&B web app." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n", "table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n", "\n", "print(table_data_url)\n", "IFrame(src=table_data_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting this to work requires more effort and more W&B-specific code\n", "than the other features we've seen so far.\n", "\n", "We'll briefly explain the implementation here, for those who are interested.\n", "\n", "We use a custom Lightning `Callback`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n", "\n", "\n", "ImageToTextTableLogger??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n", "\n", "The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n", "\n", "This behavior is sensible for metrics, which are low overhead, but not so much for media,\n", "where we'd rather subsample and avoid holding on to too much information.\n", "\n", "So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n", "\n", "The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.lit_models.base import BaseImageToTextLitModel\n", "\n", "BaseImageToTextLitModel.add_on_logged_batches??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Projects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Everything we've seen so far has been related to a single run or experiment.\n", "\n", "Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n", "\n", "We organize our runs into \"projects\" and view them on the W&B \"project page\" \n", "([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n", "\n", "By default in the Lightning integration, the project name is determined based on directory information.\n", "This default can be overridden in the code when creating a `WandbLogger`,\n", "but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
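] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of the environment-variable route, the cell below sets `WANDB_PROJECT` from Python.\n", "\n", "The project name is a made-up placeholder, not one used elsewhere in the course,\n", "and the variable needs to be set before a run is initialized to have any effect.\n", "From a terminal, the equivalent would be to prefix the training command with `WANDB_PROJECT=...`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# hypothetical project name, chosen only for illustration\n", "os.environ[\"WANDB_PROJECT\"] = \"fsdl-text-recognizer-sandbox\"\n", "\n", "# runs initialized later in this process should pick the name up from the environment\n", "print(os.environ[\"WANDB_PROJECT\"])"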
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what the project page looks like for a longer-running project with lots of experiments.\n", "\n", "The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n", "\n", "print(project_url)\n", "IFrame(src=project_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n", "\n", "We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Artifacts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond logging metrics and metadata from runs,\n", "we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below pulls up all of the artifacts associated with the experiment we just ran."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Click on one of the `model` checkpoints -- the specific version doesn't matter.\n", "\n", "There are a number of tabs here.\n", "\n", "The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n", "\n", "The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n", "which are added by default by the `WandbLogger`.\n", "\n", "The \"Files\" tab contains the actual file contents of the artifact.\n", "\n", "On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n", "including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n", "\n", "You can click on these to explore the different versions and even directly compare them.\n", "\n", "If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Artifact storage is part of the W&B free tier.\n", "\n", "The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n", "\n", "The former is sufficient to store ~700 model checkpoints for the Text Recognizer." 
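] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough check on that estimate, the cell below works through the arithmetic.\n", "\n", "It rests on two assumptions of ours, not figures from W&B: decimal gigabytes and a checkpoint size of about 140MB.\n", "Swap in the size of your own checkpoints to see how far the free tier would stretch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "free_tier_bytes = 100 * 10**9  # 100GB, assuming decimal units\n", "checkpoint_bytes = 140 * 10**6  # assumed checkpoint size, roughly 140MB\n", "\n", "# whole checkpoints that fit if nothing else is stored\n", "n_checkpoints = free_tier_bytes // checkpoint_bytes\n", "print(n_checkpoints)  # 714"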
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can track your data storage and compare it to your limits at this URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n", "\n", "print(storage_tracker_url)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Programmatic Access" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also programmatically access our data and metadata via the `wandb` API\n", "([docs](https://docs.wandb.ai/guides/track/public-api-guide)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wb_api = wandb.Api()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run = wb_api.run(\"/\".join( # fetch a run given\n", "    [last_expt.entity, # the user or org it was logged to\n", "    last_expt.project, # the \"project\", usually one of several per repo/application\n", "    last_expt.id] # and a unique ID\n", "))\n", "\n", "hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n", "\n", "hist.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hist.groupby(\"epoch\")[\"train/loss\"].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that programmatic access extends to the artifacts as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# which artifacts were created and logged?\n", "artifacts = run.logged_artifacts()\n", "\n", "for artifact in artifacts:\n", "    print(f\"artifact of type {artifact.type}: {artifact.name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thanks to our 
`ImageToTextTableLogger`,\n", "we can easily recreate training or validation data that came out of our `DataLoader`s,\n", "which is normally ephemeral:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n", "artifact_dir = Path(artifact.download(root=\"training/logs\"))\n", "image_dir = artifact_dir / \"media\" / \"images\"\n", "\n", "images = [path for path in image_dir.iterdir()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "from IPython.display import Image\n", "\n", "Image(str(random.choice(images)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Advanced W&B API Usage: MLOps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the strengths of a well-instrumented experiment tracking system is that it allows\n", "automatic relation of information:\n", "what were the inputs when this model's gradient spiked?\n", "Which models have been trained on this dataset,\n", "and what was their performance?\n", "\n", "Having access and automation around this information is necessary for \"MLOps\",\n", "which applies contemporary DevOps principles to ML projects." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cells below pull down the training data\n", "for the model currently running the FSDL Text Recognizer app.\n", "\n", "This is just intended as a demonstration of what's possible,\n", "so don't worry about understanding every piece of this,\n", "and feel free to skip past it.\n", "\n", "MLOps is still a nascent field, and these tools and workflows are likely to change.\n", "\n", "For example, just before the course launched, W&B released a\n", "[Model Registry layer](https://docs.wandb.ai/guides/models)\n", "on top of artifact logging that aims to improve the developer experience for these workflows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start from the same project we looked at in the project view:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n", "\n", "text_recognizer_project " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we search it for the text recognizer model currently being used in production:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# collect all versions of the text-recognizer ever put into production by...\n", "\n", "for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n", " if art_type.name == \"prod-ready\": # for the prod-ready type\n", " # and grabbing the text-recognizer\n", " production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n", "\n", "# and then get the one that's currently being tested in CI by...\n", "for text_recognizer in production_text_recognizers:\n", " if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n", " in_prod_text_recognizer = text_recognizer\n", "\n", "# view its metadata at the url or 
in the notebook\n", "in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n", "\n", "print(in_prod_text_recognizer_url)\n", "IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From its metadata, we can get information about how it was \"staged\" to be put into production,\n", "and in particular which model checkpoint was used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "staging_run = in_prod_text_recognizer.logged_by()\n", "\n", "training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n", "training_ckpt.name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That checkpoint was logged by a training experiment, which is available as metadata.\n", "\n", "We can look at the training run for that model, either here in the notebook or at its URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "training_run = training_ckpt.logged_by()\n", "print(training_run.url)\n", "IFrame(src=training_run.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And from there, we can access logs and metadata about training,\n", "confident that we are working with the model that is actually in production.\n", "\n", "For example, we can pull down the data we logged and analyze it locally." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_results = training_run.history(samples=10000)\n", "training_results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n", "training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 10\n", "training_results[\"validation/loss\"].dropna().iloc[idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The charts and webpages in Weights & Biases\n", "are substantially more useful than ephemeral stdouts or raw logs on disk.\n", "\n", "If you're spun up on the project,\n", "they accelerate debugging, exploration, and discovery.\n", "\n", "If not, they're not so much useful as they are overwhelming.\n", "\n", "We need to synthesize the raw logged data into information.\n", "This helps us communicate our work with other stakeholders,\n", "preserve knowledge and prevent repetition of work,\n", "and surface insights faster.\n", "\n", "These workflows are supported by the W&B Reports feature\n", "([docs here](https://docs.wandb.ai/guides/reports)),\n", "which mixes W&B charts and tables with explanatory markdown text and embeds.\n", "\n", "Below are some common report patterns,\n", "with use cases and an example of each." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the examples are from the FSDL Text Recognizer project.\n", "You can find more of them\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n", "where we've organized them into a report!"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dashboard Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dashboards are a structured subset of the output from one or more experiments,\n", "designed for quickly surfacing issues or insights,\n", "like an accuracy or performance regression\n", "or a change in the data distribution.\n", "\n", "Use cases:\n", "- show the basic state of an ongoing experiment\n", "- compare one experiment to another\n", "- select the most important charts so you can spin back up into context on a project more quickly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n", "\n", "IFrame(src=dashboard_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull Request Documentation Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In most software codebases,\n", "pull requests are a key focal point\n", "for units of work that combine\n", "short-term communication and long-term information tracking.\n", "\n", "In ML codebases, it's more difficult to bring\n", "sufficient information together to make PRs as useful.\n", "At FSDL, we like to add documentary\n", "reports with one or a small number of charts\n", "that connect logged information in the experiment management system\n", "to state in the version control software.\n", "\n", "Use cases:\n", "- communication of results within a team, e.g. 
code review\n", "- record-keeping that links pull request pages to raw logged info and makes it discoverable\n", "- improving confidence in PR correctness" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n", "\n", "IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Blog Post Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With sufficient effort, the logged data in the experiment management system\n", "can be made clear enough to be consumed,\n", "sufficiently contextualized to be useful outside the team, and\n", "even beautiful.\n", "\n", "The result is a report that's closer to a blog post than a dashboard or internal document.\n", "\n", "Use cases:\n", "- communication between teams or vertically in large organizations\n", "- external technical communication for branding and recruiting\n", "- attracting users or contributors\n", "\n", "Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma)\n", "and others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n", "\n", "IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many of our choices, like the depth of our network, the nonlinearities of our layers,\n", "and the learning rate and other parameters of our optimizer, cannot be\n", "([easily](https://arxiv.org/abs/1606.04474))\n", "chosen by descent of 
the gradient of a loss function.\n", "\n", "But these parameters that impact the values of the parameters\n", "we directly optimize with gradients, or _hyperparameters_,\n", "can still be optimized,\n", "essentially by trying options and selecting the values that worked best.\n", "\n", "In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n", "\n", "Expending more compute can squeeze small amounts of additional validation or test performance\n", "that makes for impressive results on leaderboards but typically doesn't translate\n", "into better user experience.\n", "\n", "In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n", "built into your other tooling.\n", "\n", "Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n", "([docs](https://docs.wandb.ai/guides/sweeps)).\n", "\n", "It also supports a number of more advanced tools, like\n", "[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n", "for early termination of poorly-performing runs.\n", "\n", "We can use the same training script and we don't need to run an optimization server.\n", "\n", "We just need to write a configuration yaml file\n", "([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n", "like the one below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile training/simple-overfit-sweep.yaml\n", "# first we specify what we're sweeping\n", "# we specify a program to run\n", "program: training/run_experiment.py\n", "# we optionally specify how to run it, including setting default arguments\n", "command: \n", " - ${env}\n", " - ${interpreter}\n", " - ${program}\n", " - \"--wandb\"\n", " - \"--overfit_batches\"\n", " - \"1\"\n", " - \"--log_every_n_steps\"\n", " - \"25\"\n", " - \"--max_epochs\"\n", " - \"100\"\n", " - \"--limit_test_batches\"\n", " - \"0\"\n", " - ${args} # these arguments come from the sweep parameters below\n", "\n", "# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n", "method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n", "metric:\n", " name: train/loss\n", " goal: minimize\n", "parameters: \n", " # LineCNN hyperparameters\n", " window_width:\n", " values: [8, 16, 32, 64]\n", " window_stride:\n", " values: [4, 8, 16, 32]\n", " # Transformer hyperparameters\n", " tf_layers:\n", " values: [1, 2, 4, 8]\n", " # we can also fix some values, just like we set default arguments\n", " gpus:\n", " value: 1\n", " model_class:\n", " value: LineCNNTransformer\n", " data_class:\n", " value: IAMLines\n", " loss:\n", " value: transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the config we launch a \"controller\":\n", "a lightweight process that just decides what hyperparameters to try next\n", "and coordinates the heavierweight training.\n", "\n", "This lives on the W&B servers, so there are no headaches about opening ports for communication,\n", "cleaning up when it's done, etc." 
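] }, { "cell_type": "markdown", "metadata": {}, "source": [
"As a rough illustration of what `method: random` asks the controller to do,\n",
"the sketch below draws one configuration from a parameter space shaped like the YAML above.\n",
"It is a plain-Python stand-in, not W&B's implementation,\n",
"and `sample_config` is a hypothetical helper.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"# parameter space mirroring part of the sweep config above\n",
"parameters = {\n",
"    \"window_width\": {\"values\": [8, 16, 32, 64]},\n",
"    \"window_stride\": {\"values\": [4, 8, 16, 32]},\n",
"    \"tf_layers\": {\"values\": [1, 2, 4, 8]},\n",
"    \"gpus\": {\"value\": 1},  # fixed, not swept\n",
"}\n",
"\n",
"\n",
"def sample_config(parameters, rng):\n",
"    \"\"\"Pick one value per parameter: uniformly from `values`, or the fixed `value`.\"\"\"\n",
"    config = {}\n",
"    for name, spec in parameters.items():\n",
"        if \"values\" in spec:\n",
"            config[name] = rng.choice(spec[\"values\"])\n",
"        else:\n",
"            config[name] = spec[\"value\"]\n",
"    return config\n",
"\n",
"\n",
"rng = random.Random(0)  # seeded so the draw is repeatable\n",
"config = sample_config(parameters, rng)\n",
"print(config)\n",
"```"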
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n", "simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we can launch an \"agent\" to follow the orders of the controller:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "\n", "# interrupt twice to terminate this cell if it's running too long,\n", "# it can be over 15 minutes with some hyperparameters\n", "\n", "!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n", "\n", "If not provided, the agent will run forever for random or Bayesian sweeps\n", "or until the sweep is terminated, which can be done from the W&B interface." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agents make for a slick workflow for distributing sweeps across GPUs.\n", "\n", "We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n", "which controls which GPUs are accessible by a process, to launch\n", "parallel agents on separate GPUs on the same machine." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n", "# open another terminal\n", "CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n", "# and so on\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include optional exercises with the labs for learners who want to dive deeper on specific topics." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟Contribute to a hyperparameter search." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n", "\n", "There are ~10,000,000 potential hyperparameter combinations,\n", "and each takes 30 minutes to test,\n", "so checking each possibility will take over 500 years of compute time.\n", "Best get cracking then!\n", "\n", "Run the cell below to pull up a dashboard and print the URL where you can check on the current status." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_entity = \"fullstackdeeplearning\"\n", "sweep_project = \"fsdl-line-recognizer-2022\"\n", "sweep_id = \"e0eo43eu\"\n", "sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n", "\n", "print(sweep_url)\n", "IFrame(src=sweep_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve information about the sweep from the API,\n", "including the hyperparameters being swept over." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparams = sweep_info.config[\"parameters\"]\n", "hyperparams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you'd like to contribute to this sweep,\n", "run the cell below after changing the count to a number greater than 0.\n", "\n", "Each iteration runs for 30 minutes if it does not crash,\n", "e.g. due to out-of-memory errors." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "count = 0 # off by default, increase it to join in!\n", "\n", "if count:\n", " !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Write some manual logging in `wandb`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the FSDL Text Recognizer codebase,\n", "we almost exclusively log to W&B through Lightning,\n", "rather than through the `wandb` Python SDK.\n", "\n", "If you're interested in learning how to use W&B directly, e.g. with another training framework,\n", "try out this quick exercise that introduces the key players in the SDK." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n", "\n", "It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n", "\n", "Add W&B metric and artifact logging to this cell:\n", "- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n", "- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import random\n", "\n", "import wandb\n", "\n", "\n", "os.makedirs(\"logs\", exist_ok=True)\n", "\n", "project = \"trying-wandb\"\n", "config = {\"steps\": 50}\n", "\n", "\n", "with wandb.init(project=project, config=config) as run:\n", " steps = wandb.config[\"steps\"]\n", " \n", " for ii in range(steps):\n", " loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n", " \n", " with open(\"logs/hello.txt\", \"w\") as f:\n", " f.write(\"hello from wandb, 
my dudes!\")\n", " \n", " run_id = run.id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hello_run = wb_api.run(f\"{project}/{run_id}\")\n", "\n", "# check for logged loss data\n", "if \"loss\" not in hello_run.history().keys():\n", " print(\"loss not logged 🥲\")\n", "else:\n", " print(\"loss logged successfully 🥞\")\n", " if len(hello_run.history()[\"loss\"]) != steps:\n", " print(\"loss not logged on all steps 🥲\")\n", " else:\n", " print(\"loss logged on all steps 🥞\")\n", "\n", "artifacts = hello_run.logged_artifacts()\n", "\n", "# check for artifact with the right name\n", "if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n", " print(\"hello artifact not logged 🥲\")\n", "else:\n", " print(\"hello artifact logged successfully 🥞\")\n", " # check for the file inside the artifacts\n", " if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n", " print(\"could not find hello.txt 🥲\")\n", " else:\n", " print(\"hello.txt logged successfully 🥞\")\n", " \n", " \n", "hello_run" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you observe interesting phenomena during training,\n", "from promising hyperparameter combos to software bugs to strange model behavior,\n", "turn the charts into a W&B report and share it with the FSDL community or\n", "[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n", "with a link to them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# check the sweep_info.config above to see the model and data hyperparameters\n", "# read through the --help output for all potential arguments\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n", " --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n", " --help # remove this line to run an experiment instead of printing help\n", " \n", "last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟🌟🌟 Add logging of tensor statistics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to logging model inputs and outputs as human-interpretable media,\n", "it's also frequently useful to see information about their numerical values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're interested in learning more about metric calculation and logging with Lightning,\n", "use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n", "to add tensor statistic logging to the `LineCNNTransformer`.\n", "\n", "`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n", "\n", "All three are useful, but start by adding just one." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n", "- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n", " - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n", "- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n", " - Base your code on the calculation and logging of the `val_cer` metric.\n", " - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ", "collapsed_sections": [], "name": "lab04_experiments.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab05/notebooks/lab05_troubleshooting.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 05: Troubleshooting & Testing" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`\n", "- How to implement tests for ML training systems in particular\n", "- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 5\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sThWeTtV6fL_" }, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "xFP8lU4nSg1P" }, "source": [ "# Linting Python and Shell Scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "cXbdYfFlPhZ-" }, "source": [ "### Automatically linting with `pre-commit`" ] }, { "cell_type": "markdown", "metadata": { "id": "ysqqb2GjvLrz" }, "source": [ "We 
want to keep our code clean and uniform across developers\n", "and time.\n", "\n", "Applying the cleanliness checks and style rules should be\n", "as painless and automatic as possible.\n", "\n", "For this purpose, we recommend bundling linting tools together\n", "and enforcing them on all commits with\n", "[`pre-commit`](https://pre-commit.com/)." ] }, { "cell_type": "markdown", "metadata": { "id": "XvqtZChKvLr0" }, "source": [ "In addition to running on every commit,\n", "`pre-commit` separates the model development environment from the environments\n", "needed for the linting tools, preventing conflicts\n", "and simplifying maintenance and onboarding." ] }, { "cell_type": "markdown", "metadata": { "id": "Y0XuIuKOXhJl" }, "source": [ "This cell runs `pre-commit`.\n", "\n", "The first time it is run on a machine, it will install the environments for all tools." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hltYGbpNvLr1" }, "outputs": [], "source": [ "!pre-commit run --all-files" ] }, { "cell_type": "markdown", "metadata": { "id": "gLw08gIkvLr1" }, "source": [ "The output lists all the checks that are run and whether they passed.\n", "\n", "Notice there are a number of simple version-control hygiene practices included\n", "that aren't even specific to Python, much less to machine learning.\n", "\n", "For example, several of the checks prevent accidental commits with private keys, large files, \n", "leftover debugger statements, or merge conflict annotations in them."
] }, { "cell_type": "markdown", "metadata": { "id": "RHEEjb9kvLr1" }, "source": [ "These linting actions are configured via\n", "([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n", "a YAML file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dgXa8BzrvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml" ] }, { "cell_type": "markdown", "metadata": { "id": "8HYc_WbTvLr2" }, "source": [ "Most of the general cleanliness checks are from hooks built by `pre-commit`.\n", "\n", "See the comments and links in the `.pre-commit-config.yaml` for more:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K9rTgRqzvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep repos -A 15" ] }, { "cell_type": "markdown", "metadata": { "id": "1ptkO7aPvLr2" }, "source": [ "Let's take a look at the section of the file\n", "that applies most of our Python style enforcement with\n", "[`flake8`](https://flake8.pycqa.org/en/latest/):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ALsRKfcevLr3", "scrolled": true }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10" ] }, { "cell_type": "markdown", "metadata": { "id": "a_Q0BwQUXbg6" }, "source": [ "The majority of the style checking behavior we want comes from the\n", "`additional_dependencies`, which are\n", "[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n", "that extend `flake8`'s list of lints.\n", "\n", "Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n", "\n", "We keep the configuration information for `flake8`\n", "separate from that for `pre-commit`\n", "in case we want to use additional tools with `flake8`,\n", "e.g. 
if some developers want to integrate it directly into their editor,\n", "and so that if we change away from `pre-commit`\n", "but keep `flake8` we don't have to\n", "recreate our configuration in a different tool.\n", "\n", "As much as possible, codebases should strive for single sources of truth\n", "and link back to those sources of truth with documentation or comments,\n", "as in the last line above.\n", "\n", "Let's take a look at the contents of `.flake8`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "doC_4WQwvLr3" }, "outputs": [], "source": [ "!cat .flake8" ] }, { "cell_type": "markdown", "metadata": { "id": "0Nq6HnyU0M47" }, "source": [ "There's a lot here! We'll focus on the most important bits." ] }, { "cell_type": "markdown", "metadata": { "id": "U4PiB8CPvLr3" }, "source": [ "Linting tools in Python generally work by emitting error codes\n", "with one or more letters followed by three numbers.\n", "The `select` argument picks which error codes we want to check for.\n", "Error codes are matched by prefix,\n", "so for example `B` matches `BTS101` and\n", "`G1` matches `G102` and `G199` but not `ARG404`.\n", "\n", "Certain codes are `ignore`d in the default `flake8` style,\n", "which is done via the `ignore` argument,\n", "and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n", "For example, we rely on `black` to do our formatting,\n", "so we ignore some of `flake8`'s formatting codes.\n", "\n", "Together, these settings define our project's particular style.\n", "\n", "But not every file fits this style perfectly.\n", "Most of the conventions in `black` and `flake8` come from the style-defining\n", "[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n", "which exhorts you to \"know when to be inconsistent\".\n", "\n", "To allow ourselves to be inconsistent when we know we should be,\n", "`flake8` includes `per-file-ignores`,\n", "which let us ignore specific warnings in specific files.\n", 
"This is one of the \"escape valves\"\n", "that makes style enforcement tolerable.\n", "We can also `exclude` files in the `pre-commit` config itself.\n", "\n", "For details on selecting and ignoring,\n", "see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html)\n", "\n", "For definitions of the error codes from `flake8` itself,\n", "see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n", "Individual extensions list their added error codes in their documentation,\n", "e.g. `darglint` does so\n", "[here](https://github.com/terrencepreilly/darglint#error-codes)." ] }, { "cell_type": "markdown", "metadata": { "id": "NL0TpyPsvLr4" }, "source": [ "The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n", "\n", "You can read more about each in their documentation:\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "markdown", "metadata": { "id": "mFsZC0a7vLr4" }, "source": [ "### Linting via a script and using `shellcheck`" ] }, { "cell_type": "markdown", "metadata": { "id": "RYjpuFwjXkJc" }, "source": [ "To avoid needing to think about `pre-commit`\n", "(was the command `pre-commit run` or `pre-commit check`?)\n", "while developing locally,\n", "we might put our linters into a shell script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mXlLFWmavLr4" }, "outputs": [], "source": [ "!cat tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "PPxHpRIB3nbw" }, "source": [ "These kinds of short and simple shell scripts are common in projects\n", "of intermediate 
size.\n", "\n", "They are useful for adding automation and reducing friction." ] }, { "cell_type": "markdown", "metadata": { "id": "TMuPBpAi2qwl" }, "source": [ "But these scripts are code,\n", "and all code is susceptible to bugs and subject to concerns of style consistency." ] }, { "cell_type": "markdown", "metadata": { "id": "SQRg3ZqXvLr4" }, "source": [ "We can't check these scripts with tools that lint Python code,\n", "so we include a shell script linting tool,\n", "[`shellcheck`](https://www.shellcheck.net/),\n", "in our `pre-commit`.\n", "\n", "More so than checking for correct style,\n", "this tool checks for common bugs or surprising behaviors of shells,\n", "which are unfortunately numerous." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zkfhE1srvLr4" }, "outputs": [], "source": [ "script_filename = \"tasks/lint.sh\"\n", "!pre-commit run shellcheck --files {script_filename}" ] }, { "cell_type": "markdown", "metadata": { "id": "KXU9TRrwvLr4" }, "source": [ "That script has already been tested, so we don't see any errors.\n", "\n", "Try copying over a script you've written yourself or\n", "even from a popular repo that you like\n", "(by adding to the notebook directory or by making a cell\n", "with `%%writefile` at the top)\n", "and test it by changing the `script_filename`.\n", "\n", "You'd be surprised at the classes of subtle bugs possible in bash!" 
] }, { "cell_type": "markdown", "metadata": { "id": "81MhAL-TvLr5" }, "source": [ "### Try \"unofficial bash strict mode\" for louder failures in scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "hSwhs_zUvLr5" }, "source": [ "Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n", "[@redsymbol](https://twitter.com/redsymbol),\n", "which appear at the top of the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o-j0vSxEvLr5" }, "outputs": [], "source": [ "!head -n 3 tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "d2iJU5jlvLr5" }, "source": [ "The core idea of strict mode is to fail more loudly.\n", "This is a desirable behavior of scripts,\n", "like the ones we're writing,\n", "even though it's an undesirable behavior for an interactive shell --\n", "it would be unpleasant to be logged out every time you hit an error.\n", "\n", "`set -u` means scripts fail if a variable's value is `u`nset,\n", "i.e. not defined.\n", "Otherwise bash is perfectly happy to allow you to reference undefined variables.\n", "The result is just an empty string, which can lead to maddeningly weird behavior.\n", "\n", "`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n", "rather than using the exit code of the last command.\n", "Unix tools are perfectly happy to work on nonsense input,\n", "like sorting error messages, instead of the filenames you meant to send.\n", "\n", "You can read more about these choices\n", "[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n", "and considerations for working with other non-conforming scripts in \"strict mode\"\n", "and for handling resource teardown when scripts error out." 
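] }, { "cell_type": "markdown", "metadata": {}, "source": [
"To see these two settings in action without leaving Python,\n",
"the sketch below runs tiny bash one-liners through `subprocess`\n",
"and compares exit codes with and without each flag.\n",
"It assumes `bash` is on the `PATH`; the one-liners, the helper `bash_exit_code`,\n",
"and the variable name `NOT_DEFINED_ANYWHERE` are made up for illustration.\n",
"\n",
"```python\n",
"import subprocess\n",
"\n",
"\n",
"def bash_exit_code(script):\n",
"    \"\"\"Run a bash snippet non-interactively and return its exit code.\"\"\"\n",
"    return subprocess.run([\"bash\", \"-c\", script], capture_output=True).returncode\n",
"\n",
"\n",
"# without pipefail, the pipeline reports the status of its last command (`true`)\n",
"lax_pipe = bash_exit_code(\"false | true\")\n",
"# with pipefail, the failure of `false` propagates\n",
"strict_pipe = bash_exit_code(\"set -o pipefail; false | true\")\n",
"\n",
"# without -u, an undefined variable silently expands to an empty string\n",
"lax_unset = bash_exit_code('echo \"$NOT_DEFINED_ANYWHERE\"')\n",
"# with -u, referencing it is an error\n",
"strict_unset = bash_exit_code('set -u; echo \"$NOT_DEFINED_ANYWHERE\"')\n",
"\n",
"print(lax_pipe, strict_pipe, lax_unset, strict_unset)\n",
"```\n",
"\n",
"The lax variants exit with code zero even though something went wrong, which is exactly the quiet failure that strict mode is meant to surface."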
] }, { "cell_type": "markdown", "metadata": { "id": "s1XqsrU_XWWS" }, "source": [ "# Testing ML Codebases" ] }, { "cell_type": "markdown", "metadata": { "id": "CPNzeq3NYF2W" }, "source": [ "## Testing Python code with `pytest`" ] }, { "cell_type": "markdown", "metadata": { "id": "zq5e_x6gc9Vu" }, "source": [ "\n", "ML codebases are Python first and foremost, so first let's get some Python tests going." ] }, { "cell_type": "markdown", "metadata": { "id": "0DC3GxYz6_R9" }, "source": [ "At a basic level,\n", "we can write functions that `assert`\n", "that our code behaves as expected in\n", "a given scenario and include them in the same module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Rvd-GNwv63W1" }, "outputs": [], "source": [ "from text_recognizer.lit_models.metrics import test_character_error_rate\n", "\n", "test_character_error_rate??" ] }, { "cell_type": "markdown", "metadata": { "id": "iVB2TsQS5BTq" }, "source": [ "The standard tool for testing Python code is\n", "[`pytest`](https://docs.pytest.org/en/7.1.x/).\n", "\n", "We can use it as a command-line tool in a variety of ways,\n", "including to execute these kinds of tests.\n", "\n", "If passed a filename, `pytest` will look for\n", "any classes that start with `Test` or\n", "any functions that start with `test_` and run them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u8sQguyJvLr6", "scrolled": false }, "outputs": [], "source": [ "!pytest text_recognizer/lit_models/metrics.py" ] }, { "cell_type": "markdown", "metadata": { "id": "92tkBCllvLr6" }, "source": [ "After the results of the tests (pass or fail) are returned,\n", "you'll see a report of \"coverage\" from\n", "[`codecov`](https://about.codecov.io/).\n", "\n", "This coverage report tells us which files and how many lines in those files\n", "were touched by the testing suite."
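] }, { "cell_type": "markdown", "metadata": {}, "source": [
"Returning to the in-module, `assert`-based tests described above,\n",
"here is a minimal sketch of the style using a toy function rather than anything from the codebase.\n",
"Both `edit_distance` and `test_edit_distance` are hypothetical names;\n",
"saved in a module, the latter would be collected by `pytest` because of its `test_` prefix.\n",
"\n",
"```python\n",
"def edit_distance(a, b):\n",
"    \"\"\"Levenshtein distance between two strings, by dynamic programming.\"\"\"\n",
"    previous = list(range(len(b) + 1))  # distances from the empty prefix of a\n",
"    for i, char_a in enumerate(a, start=1):\n",
"        current = [i]\n",
"        for j, char_b in enumerate(b, start=1):\n",
"            substitution = previous[j - 1] + (char_a != char_b)\n",
"            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))\n",
"        previous = current\n",
"    return previous[-1]\n",
"\n",
"\n",
"def test_edit_distance():\n",
"    \"\"\"Check a few hand-computed cases, including the degenerate ones.\"\"\"\n",
"    assert edit_distance(\"\", \"\") == 0\n",
"    assert edit_distance(\"abc\", \"abc\") == 0\n",
"    assert edit_distance(\"\", \"abc\") == 3\n",
"    assert edit_distance(\"kitten\", \"sitting\") == 3\n",
"\n",
"\n",
"test_edit_distance()  # pytest would call this for us\n",
"```"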
] }, { "cell_type": "markdown", "metadata": { "id": "PllSUe0s5xvU" }, "source": [ "We do not actually need to provide the names of files with tests in them to `pytest`\n", "in order for it to run our tests." ] }, { "cell_type": "markdown", "metadata": { "id": "4qOBHJnTZM9x" }, "source": [ "By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n", "\n", "It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n", "to separate these from the rest of your code\n", "in a folder or folders named `tests`,\n", "rather than scattering them around the repo." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "acjsYTNSvLr6" }, "outputs": [], "source": [ "!ls text_recognizer/tests" ] }, { "cell_type": "markdown", "metadata": { "id": "WZQQZUF0vLr6" }, "source": [ "Let's take a look at a specific example:\n", "the tests for some of our utilities around\n", "custom PyTorch Lightning `Callback`s." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oS0xKv1evLr6" }, "outputs": [], "source": [ "from text_recognizer.tests import test_callback_utils\n", "\n", "\n", "test_callback_utils.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "lko8msn-vLr7" }, "source": [ "Notice that we can easily import this as a module!\n", "\n", "That's another benefit of organizing tests into specialized files." ] }, { "cell_type": "markdown", "metadata": { "id": "5A85FUNv75Fr" }, "source": [ "The particular utility we're testing\n", "here is designed to prevent crashes:\n", "it checks for a particular type of error and turns it into a warning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jl4-DiVe76sw" }, "outputs": [], "source": [ "from text_recognizer.callbacks.util import check_and_warn\n", "\n", "check_and_warn??" 
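] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The general pattern, checking for a missing capability and warning instead of crashing,\n",
"can be sketched in a few lines of standard-library Python.\n",
"This is not the repo's `check_and_warn`:\n",
"`missing_feature_warning`, `PlainLogger`, and `TableLogger` are hypothetical stand-ins.\n",
"\n",
"```python\n",
"import warnings\n",
"\n",
"\n",
"def missing_feature_warning(logger, attribute, feature):\n",
"    \"\"\"Warn and return True if `logger` lacks `attribute`; otherwise return False.\"\"\"\n",
"    if not hasattr(logger, attribute):\n",
"        warnings.warn(f\"Unable to log {feature}: logger has no attribute {attribute}.\")\n",
"        return True\n",
"    return False\n",
"\n",
"\n",
"class PlainLogger:\n",
"    \"\"\"Stand-in for a logger without table support.\"\"\"\n",
"\n",
"\n",
"class TableLogger:\n",
"    \"\"\"Stand-in for a logger with table support.\"\"\"\n",
"\n",
"    def log_table(self, *args, **kwargs):\n",
"        pass\n",
"\n",
"\n",
"with warnings.catch_warnings(record=True) as caught:\n",
"    warnings.simplefilter(\"always\")  # make sure no warning is filtered out\n",
"    skipped_plain = missing_feature_warning(PlainLogger(), \"log_table\", \"tables\")\n",
"    skipped_table = missing_feature_warning(TableLogger(), \"log_table\", \"tables\")\n",
"\n",
"print(skipped_plain, skipped_table, len(caught))\n",
"```\n",
"\n",
"A caller can use the return value to skip the unsupported logging call, so training continues with a warning rather than an exception."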
] }, { "cell_type": "markdown", "metadata": { "id": "B6E0MhduvLr7" }, "source": [ "Error-handling code is a common cause of bugs,\n", "a fact discovered\n", "[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n", "so it's very important to test it well!\n", "\n", "We start with a very basic test,\n", "which does not touch anything\n", "outside of the Python standard library,\n", "even though this tool is intended to be used\n", "with more complex features of third-party libraries,\n", "like `wandb` and `tensorboard`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xx5koQmJvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_simple??" ] }, { "cell_type": "markdown", "metadata": { "id": "MZe9-JVjvLr7" }, "source": [ "Here, we are just testing the core logic.\n", "This test won't catch many bugs,\n", "but when it does fail, something has gone seriously wrong.\n", "\n", "These kinds of tests are important for resolving a bug:\n", "we learn nearly as much from the tests that passed\n", "as we did from the tests that failed.\n", "If this test has failed, possibly along with others,\n", "we can rule out an issue in one of the large external codebases\n", "touched in the other tests, saving us lots of time in our troubleshooting.\n", "\n", "The reasoning for the test is explained in the docstrings, \n", "which are close to the code.\n", "\n", "Your test suite should be as welcoming\n", "as the rest of your codebase!\n", "The people reading it, for example yourself in six months, \n", "are likely upset and in need of some kindness.\n", "\n", "More practically, we want to keep our time to resolve errors as short as possible,\n", "and five minutes to write a good docstring now\n", "can save five minutes during an outage, when minutes really matter."
] }, { "cell_type": "markdown", "metadata": { "id": "Om9k-uXhvLr7" }, "source": [ "That basic test is a start, but it's not enough by itself.\n", "There's a specific error case that triggered the addition of this code.\n", "\n", "So we test that it's handled as expected." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fjbsb5FvvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_tblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "CGAIZTUjvLr7" }, "source": [ "That test can fail if the libraries change around our code,\n", "i.e. if the `TensorBoardLogger` gets a `log_table` method.\n", "\n", "We want to be careful when making assumptions\n", "about other people's software,\n", "especially for fast-moving libraries like Lightning.\n", "If we test that those assumptions hold willy-nilly,\n", "we'll end up with tests that fail because of\n", "harmless changes in our dependencies.\n", "\n", "Tests that require a ton of maintenance and updating\n", "without leading to code improvements soak up\n", "more engineering time than they save\n", "and cause distrust in the testing suite.\n", "\n", "We include this test because `TensorBoardLogger` getting\n", "a `log_table` method will _also_ change the behavior of our code\n", "in a breaking way, and we want to catch that before it breaks\n", "a model training job." ] }, { "cell_type": "markdown", "metadata": { "id": "jsy95KAvvLr7" }, "source": [ "Adding error handling can also accidentally kill the \"happy path\"\n", "by raising an error incorrectly.\n", "\n", "So we explicitly test the _absence of an error_,\n", "not just its presence:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LRlIOkjmvLr8" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_wandblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "osiqpLynvLr8" }, "source": [ "There are more tests we could build, e.g. 
manipulating classes and testing the behavior,\n", "testing more classes that might be targeted by `check_and_warn`, or\n", "asserting that warnings are raised to the command line.\n", "\n", "But these three basic tests are likely to catch most changes that would break our code here,\n", "and they're a lot easier to write than the others.\n", "\n", "If this utility starts to get more usage and become a critical path for lots of features, we can always add more!" ] }, { "cell_type": "markdown", "metadata": { "id": "dm285JE5vLr8" }, "source": [ "## Interleaving testing and documentation with `doctests`" ] }, { "cell_type": "markdown", "metadata": { "id": "UHWQvgA8vLr8" }, "source": [ "One function of tests is to build user/reader confidence in code." ] }, { "cell_type": "markdown", "metadata": { "id": "wrhiJBXFvLr8" }, "source": [ "One function of documentation is to build user/reader knowledge in code." ] }, { "cell_type": "markdown", "metadata": { "id": "1vu12LDhvLr8" }, "source": [ "These functions are related. Let's put them together:\n", "put code in a docstring and test that code.\n", "\n", "This feature is part of the\n", "Python standard library via the\n", "[`doctest` module](https://docs.python.org/3/library/doctest.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "rmfIOwXd-Qt7" }, "source": [ "Here's an example from our `torch` utilities.\n", "\n", "The `first_appearance` function can be used to\n", "e.g. quickly look for stop tokens,\n", "giving the length of each sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZzURGcD9vLr8" }, "outputs": [], "source": [ "from text_recognizer.lit_models.util import first_appearance\n", "\n", "\n", "first_appearance??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0VtYcJ1WvLr8" }, "source": [ "Notice that in the \"Examples\" section,\n", "there's a short block of code formatted as a\n", "Python interpreter session,\n", "complete with outputs.\n", "\n", "We can copy and paste that code and\n", "check that we get the right outputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dj4lNOxJvLr9" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y9AWHFoIvLr9" }, "source": [ "We can run the test with `pytest` by passing a command line argument,\n", "`--doctest-modules`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JMaAxv5ovLr9" }, "outputs": [], "source": [ "!pytest --doctest-modules text_recognizer/lit_models/util.py" ] }, { "cell_type": "markdown", "metadata": { "id": "6-2_aOUfvLr9" }, "source": [ "With the\n", "[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n", "running `doctest`s happens automatically\n", "when `pytest` is invoked." ] }, { "cell_type": "markdown", "metadata": { "id": "my_keokPvLr9" }, "source": [ "## Basic tests for data code" ] }, { "cell_type": "markdown", "metadata": { "id": "Qj3Bq_j2_A8o" }, "source": [ "ML code can be hard to test\n", "since it involves very heavy artifacts, like models and data,\n", "and very expensive jobs, like training."
] }, { "cell_type": "markdown", "metadata": { "id": "DT5OmgrQvLr9" }, "source": [ "For testing our data-handling code in the FSDL codebase,\n", "we mostly just use `assert`s,\n", "which throw errors when behavior differs from expectation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bdzn5g4TvLr9" }, "outputs": [], "source": [ "!grep \"assert\" -r text_recognizer/data" ] }, { "cell_type": "markdown", "metadata": { "id": "2aTlfu4_vLr-" }, "source": [ "This isn't great practice,\n", "especially as a codebase grows,\n", "because we can't easily know when these are executed\n", "or incorporate them into\n", "testing automation and coverage analysis tools." ] }, { "cell_type": "markdown", "metadata": { "id": "IaMTdmbZ_mkW" }, "source": [ "So it's preferable to collect up these assertions of simple data properties\n", "into tests that are run like our other tests.\n", "\n", "The test below checks whether any data is leaking\n", "between training, validation, and testing." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qx7cxiDdvLr-" }, "outputs": [], "source": [ "from text_recognizer.tests.test_iam import test_iam_data_splits\n", "\n", "\n", "test_iam_data_splits??" 
] }, { "cell_type": "markdown", "metadata": { "id": "16TJwhd1vLr-" }, "source": [ "Notice that we were able to load the test into the notebook\n", "because it is in a module,\n", "and so we can run it here as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mArITFkYvLr-" }, "outputs": [], "source": [ "test_iam_data_splits()" ] }, { "cell_type": "markdown", "metadata": { "id": "E4F2uaclvLr-" }, "source": [ "But we're checking something pretty simple here,\n", "so the new code in each test is just a single line.\n", "\n", "What if we wanted to test more complex properties,\n", "like comparing rows or calculating statistics?\n", "\n", "We'll end up writing more complex code that might itself have subtle bugs,\n", "requiring tests for our tests and suffering from\n", "\"tester's regress\".\n", "\n", "This is the phenomenon,\n", "named by analogy with\n", "[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n", "in sociology of science,\n", "where the validity of our tests is itself\n", "up for dispute only resolvable by testing the tests,\n", "but those tests are themselves possibly invalid." ] }, { "cell_type": "markdown", "metadata": { "id": "nUGT06gdvLr-" }, "source": [ "We cut this Gordian knot by using\n", "a library or framework that is well-tested.\n", "\n", "We recommend checking out\n", "[`great_expectations`](https://docs.greatexpectations.io/docs/)\n", "if you're looking for a high-quality data testing tool." ] }, { "cell_type": "markdown", "metadata": { "id": "dQ5vNsq3vLr-" }, "source": [ "Especially with data, some tests are particularly \"heavy\" --\n", "they take a long time,\n", "and we might want to run them\n", "on different machines\n", "and on a different schedule\n", "than our other tests." 
] }, { "cell_type": "markdown", "metadata": { "id": "xephcb0LvLr-" }, "source": [ "For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n", "\n", "We can't just use a cached version of the data,\n", "since that won't actually execute the code!\n", "\n", "This test will take\n", "as long to run\n", "and consume as many resources as\n", "a full download of the data." ] }, { "cell_type": "markdown", "metadata": { "id": "YSN4w2EqvLr-" }, "source": [ "`pytest` allows the separation of tests\n", "into suites with `mark`s,\n", "which \"tag\" tests with names." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V0rScrcXvLr_", "scrolled": false }, "outputs": [], "source": [ "!pytest --markers | head -n 10" ] }, { "cell_type": "markdown", "metadata": { "id": "lr5Ca7B0vLr_" }, "source": [ "We can choose to run tests with a given mark\n", "or to skip tests with a given mark, \n", "among other basic logical operations around combining and filtering marks,\n", "with `-m`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xmw-Eb1ZvLr_" }, "outputs": [], "source": [ "!wandb login # one test requires wandb authentication\n", "\n", "!pytest -m \"not data and not slow\"" ] }, { "cell_type": "markdown", "metadata": { "id": "5LuERxOXX_UJ" }, "source": [ "## Testing training with memorization tests" ] }, { "cell_type": "markdown", "metadata": { "id": "AnWLN4lRvLsA" }, "source": [ "Training is the process by which we convert inert data into executable models,\n", "so it is dependent on both.\n", "\n", "We decouple checking whether the script has a critical bug\n", "from whether the data or model code is broken\n", "by testing on some basic \"fake data\",\n", "based on a utility from `torchvision`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k4NIc3uWvLsA" }, "outputs": [], "source": [ "from text_recognizer.data import FakeImageData\n", "\n", "\n", "FakeImageData.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "deN0swwlvLsA" }, "source": [ "We then test on the actual data with a smaller version of the real model.\n", "\n", "We use the Lightning `--fast_dev_run` feature,\n", "which sets the number of training, validation, and test batches to `1`.\n", "\n", "We use a smaller version so that this test can run in just a few minutes\n", "on a CPU without acceleration.\n", "\n", "That allows us to run our tests in environments without GPUs,\n", "which saves on costs for executing tests.\n", "\n", "Here's the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z4J0_uD9vLsA" }, "outputs": [], "source": [ "!cat training/tests/test_run_experiment.sh" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-7u9zS1vLsA", "scrolled": false }, "outputs": [], "source": [ "! ./training/tests/test_run_experiment.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "UTzfo11KClV3" }, "source": [ "The above tests don't actually check\n", "whether any learning occurs,\n", "they just check\n", "whether training runs mechanically,\n", "without any errors.\n", "\n", "We also need a\n", "[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n", "for learning.\n", "For that we recommend checking whether\n", "the model can learn the right\n", "outputs for a single batch --\n", "to \"memorize\" the outputs for\n", "a particular input.\n", "\n", "This memorization test won't\n", "catch every bug or issue in training,\n", "which is notoriously difficult,\n", "but it will flag\n", "some of the most serious issues." ] }, { "cell_type": "markdown", "metadata": { "id": "0DVSp3aAvLsA" }, "source": [ "The script below runs a memorization test."
] }, { "cell_type": "markdown", "metadata": { "id": "2DFVVrxpvLsA" }, "source": [ "It takes up to two arguments:\n", "a `MAX`imum number of `EPOCHS` to run for and\n", "a `CRITERION` value of the loss to test against.\n", "\n", "The test passes if the loss is lower than the `CRITERION` value\n", "after the `MAX`imum number of `EPOCHS` has passed." ] }, { "cell_type": "markdown", "metadata": { "id": "oEhJH0e5vLsB" }, "source": [ "The important line in this script is the one that invokes our training script,\n", "`training/run_experiment.py`.\n", "\n", "The arguments to `run_experiment` have been tuned for maximum possible speed:\n", "turning off regularization, shrinking the model,\n", "and skipping parts of Lightning that we don't want to test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T-fFs1xEvLsB" }, "outputs": [], "source": [ "!cat training/tests/test_memorize_iam.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "X-47tUA_YNGe" }, "source": [ "If you'd like to see what a memorization run looks like,\n", "flip the `running_memorization` flag to `True`\n", "and watch the results stream in to W&B.\n", "\n", "The cell should run in about ten minutes on a commodity GPU." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GwTEsZwKvLsB" }, "outputs": [], "source": [ "%%time\n", "running_memorization = False\n", "\n", "if running_memorization:\n", "    max_epochs = 1000\n", "    loss_criterion = 0.05\n", "    !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Troubleshooting model speed with the PyTorch Profiler" ] }, { "cell_type": "markdown", "metadata": { "id": "DpbN-Om2Drf-" }, "source": [ "Testing code is only half the story here:\n", "we also need to fix the issues that our tests flag.\n", "This is the process of troubleshooting.\n", "\n", "In this lab,\n", "we'll focus on troubleshooting model performance issues:\n", "what to do when your model runs too slowly." ] }, { "cell_type": "markdown", "metadata": { "id": "NZzwELPXvLsD" }, "source": [ "Troubleshooting deep neural networks for speed is challenging.\n", "\n", "There are at least three different common approaches,\n", "each with an increasing level of skill required:\n", "\n", "1. Follow best practices advice from others\n", "([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n", "[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n", "2. Take code that runs slowly and use empirical observations to iteratively improve it.\n", "3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n", "\n", "For the full stack deep learning engineer,\n", "the final level is typically out of reach,\n", "unless you're specializing in the model performance\n", "part of the stack in particular.\n", "\n", "So we recommend reaching the middle level,\n", "and this segment of the lab walks through the\n", "tools that make this easier."
] }, { "cell_type": "markdown", "metadata": { "id": "3_yp87UrFZ8M" }, "source": [ "Because neural network training involves GPU acceleration,\n", "generic Python profiling tools like\n", "[`py-spy`](https://github.com/benfred/py-spy)\n", "won't work, and\n", "we'll need tools specialized for tracing and profiling DNN training." ] }, { "cell_type": "markdown", "metadata": { "id": "yspsYVFGEyZm" }, "source": [ "In general, these tools are for observing what happens while your code is executing:\n", "_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n", "\n", "Because they help us observe the execution in detail,\n", "they will also help us understand just what is going on during\n", "a PyTorch training step in greater detail." ] }, { "cell_type": "markdown", "metadata": { "id": "YqXq2hKuvLsE" }, "source": [ "To support profiling and tracing,\n", "we've added a new argument to `training/run_experiment.py`, `--profile`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z_GMMViWvLsE" }, "outputs": [], "source": [ "!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\"" ] }, { "cell_type": "markdown", "metadata": { "id": "ZldoksHPvLsE" }, "source": [ "As with experiment management, this relies mostly on features of PyTorch Lightning,\n", "which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n", "and we just add a few lines of customization:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F2iJ0_A6vLsE" }, "outputs": [], "source": [ "!cat training/run_experiment.py | grep args.profile -A 5" ] }, { "cell_type": "markdown", "metadata": { "id": "Aw3ppgndvLsE" }, "source": [ "For more on profiling with Lightning, see the\n", "[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "uCAmNW3QEtcD" }, "source": [ "The cell below runs an epoch of training with tracing and profiling turned on\n", "and then saves the results locally and to W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t4o3ylDgr46F", "scrolled": false }, "outputs": [], "source": [ "import glob\n", "\n", "import torch\n", "import wandb\n", "\n", "from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n", "\n", "\n", "# make it easier to separate these from training runs\n", "%env WANDB_JOB_TYPE=profile\n", "\n", "batch_size = 16\n", "num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n", "gpus = 1 # must be run with accelerator\n", "\n", "%run training/run_experiment.py --wandb --profile \\\n", " --max_epochs=1 \\\n", " --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n", " --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n", " --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus={gpus}\n", "\n", "latest_expt = wandb.run\n", "\n", "try: # add execution trace to logged and versioned binaries\n", "    folder = wandb.run.dir\n", "    trace_matcher = folder + \"/*.pt.trace.json\"\n", "    trace_file = glob.glob(trace_matcher)[0]\n", "    trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n", "    trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n", "    wandb.log_artifact(trace_at)\n", "except IndexError:\n", "    print(\"trace not found\")\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": { "id": "ePTkS3EqO5tN" }, "source": [ "We get out a table of statistics in the terminal,\n", "courtesy of Lightning.\n", "\n", "Each row lists an operation\n", "and provides information,\n", "described in the column headers,\n", "about the time spent on that operation\n", "across all the training steps we profiled.\n", "\n", "With
practice, some useful information can be read out from this table,\n", "but it's better to start from both a less detailed view,\n", "in the TensorBoard dashboard,\n", "and a more detailed view,\n", "using the Chrome Trace viewer." ] }, { "cell_type": "markdown", "metadata": { "id": "TzV62f3c7-Bi" }, "source": [ "## High-level statistics from the PyTorch Profiler in TensorBoard" ] }, { "cell_type": "markdown", "metadata": { "id": "mNPKXkYw8NWd" }, "source": [ "Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CbItwuT88eAV" }, "outputs": [], "source": [ "your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n", "\n", "print(your_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "jE_LooMYHFpF" }, "source": [ "If at any point you run into issues,\n", "like the description not matching what you observe,\n", "check out one of our example runs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "za2zybSwIo5C" }, "outputs": [], "source": [ "example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n", "print(example_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "xlrhl1n4HYU6" }, "source": [ "Once the TensorBoard session has loaded up,\n", "we are dropped into the Overview\n", "(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n", "for an example).\n", "\n", "In the top center, we see the **GPU Summary** for our system.\n", "\n", "In addition to the name of our GPU,\n", "there are a few configuration details and top-level statistics.\n", "They are (tersely) documented\n", "[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)." 
] }, { "cell_type": "markdown", "metadata": { "id": "MmBhUDgDLhd1" }, "source": [ "- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n", "this is effectively a coarse \"version number\" for your GPU hardware.\n", "It indexes which features are available,\n", "with more advanced features being available only at higher compute capabilities.\n", "It does not directly index the speed or memory of the GPU." ] }, { "cell_type": "markdown", "metadata": { "id": "voUgT6zuLyi0" }, "source": [ "- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase." ] }, { "cell_type": "markdown", "metadata": { "id": "Yl-IndtXE4b4" }, "source": [ "- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n", "for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n", "Tensor Cores.\n", "If you're running on an older GPU without Tensor Cores,\n", "you should consider upgrading.\n", "If you're running a more recent GPU but not seeing Tensor Core usage,\n", "you should switch to half precision (16-bit) floating point numbers,\n", "which Tensor Cores are specialized for." ] }, { "cell_type": "markdown", "metadata": { "id": "XxcUf0bBNXy_" }, "source": [ "- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n", "at a lower level than just whether something is running at all,\n", "as in utilization.\n", "Unlike utilization, reaching 100% is not generally feasible\n", "and sometimes not desirable.\n", "Increasing these numbers requires expertise in\n", "CUDA programming, so we'll target utilization instead."
] }, { "cell_type": "markdown", "metadata": { "id": "A88pQn4YMMKc" }, "source": [ "- **Execution Summary**: This table and pie chart indicate\n", "how much time within a profiled step\n", "was spent in each category.\n", "The value for \"kernel\" execution here\n", "is equal to the GPU utilization,\n", "and we want that number to be as close to 100%\n", "as possible.\n", "This summary helps us know which\n", "other operations are taking time,\n", "like memory being copied between CPU and GPU (`memcpy`)\n", "or `DataLoader`s executing on the CPU,\n", "so we can decide where the bottleneck is." ] }, { "cell_type": "markdown", "metadata": { "id": "6qjW1RlTQRPv" }, "source": [ "At the very bottom, you'll find a\n", "**Performance Recommendation**\n", "tab that sometimes suggests specific methods for improving performance.\n", "\n", "If this tab makes suggestions, you should certainly take them!" ] }, { "cell_type": "markdown", "metadata": { "id": "pWY5AhrcRQmJ" }, "source": [ "For more on using the profiler in TensorBoard,\n", "including some of the other, more detailed views\n", "available via the \"Views\" dropdown menu, see\n", "[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
] }, { "cell_type": "markdown", "metadata": { "id": "mQwrPY_H77H8" }, "source": [ "## Going deeper with the Chrome Trace Viewer" ] }, { "cell_type": "markdown", "metadata": { "id": "yhwo7fslvLsH" }, "source": [ "So far, we've seen summary-level information about our training steps\n", "in the table from Lightning and in the TensorBoard Overview.\n", "These give aggregate statistics about the computations that occurred,\n", "but understanding how to interpret those statistics\n", "and use them to speed up our networks\n", "requires understanding just what is\n", "happening in our training step.\n", "\n", "Fundamentally,\n", "all computations are processes that unfold in time.\n", "\n", "If we want to really understand our training step,\n", "we need to display it that way:\n", "what operations were occurring,\n", "on both the CPU and GPU,\n", "at each moment in time during the training step.\n", "\n", "This information on timing is collected in the trace.\n", "One of the best tools for viewing the trace over time\n", "is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)." ] }, { "cell_type": "markdown", "metadata": { "id": "wUkZItxYc20A" }, "source": [ "Let's tour the trace we just logged\n", "with an aim to really understanding just\n", "what is happening when we call\n", "`training_step`\n", "and by extension `.forward`, `.backward`, and `optimizer.step`." 
] }, { "cell_type": "markdown", "metadata": { "id": "9w9F2UA7Qctg" }, "source": [ "The Chrome Trace Viewer is built into W&B,\n", "so we can view our traces in their interface.\n", "\n", "The cell below embeds the trace inside the notebook,\n", "but you may wish to open it separately,\n", "with the \"Open page\" button or by navigating to the URL,\n", "so that you can interact with it\n", "as you read the description below.\n", "Display directly on W&B is also a bit less temperamental\n", "than display on W&B inside a notebook.\n", "\n", "Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n", "so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n", "It also can interact poorly with browser extensions (e.g. ad blockers),\n", "so you may need to deactivate them temporarily in order to see it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OMUs4aby6Rfd" }, "outputs": [], "source": [ "trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n", "trace_url = trace_files_url + \"training_step.pt.trace.json\"\n", "\n", "example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n", "\n", "print(trace_url)\n", "IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qNVpGeQtQjMG" }, "source": [ "> **Heads up!** We're about to do a tour of the\n", "> precise details of the tracing information logged\n", "> during the execution of the training code.\n", "> The only way to learn how to troubleshoot model performance\n", "> empirically is to look at the details,\n", "> but the details depend on the precise machine being used\n", "> -- GPU and CPU and RAM.\n", "> That means even within Colab,\n", "> these details change from session to session.\n", "> So if you don't 
observe a phenomenon or feature\n", "> described in the tour below, check out\n", "> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n", "> on W&B while reading through the next section of the lab,\n", "> and return to your trace once you understand the trace viewer better at the end.\n", "> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n", "> can sometimes be a bit janky." ] }, { "cell_type": "markdown", "metadata": { "id": "kXMcBhnCgdN_" }, "source": [ "This trace reveals, in nanosecond-level detail,\n", "what's going on inside of a `training_step`\n", "on both the GPU and the CPU.\n", "\n", "Time is on the horizontal axis.\n", "Colored bars represent method calls,\n", "and the methods called by a method are placed underneath it vertically,\n", "a visualization known as an\n", "[icicle chart](https://www.brendangregg.com/flamegraphs.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "67BsNzDfVIeg" }, "source": [ "Let's orient ourselves with some gross features:\n", "the forwards pass,\n", "GPU kernel execution,\n", "the backwards pass,\n", "and the optimizer step." 
] }, { "cell_type": "markdown", "metadata": { "id": "IBEFgtRCKqrh" }, "source": [ "### The forwards pass" ] }, { "cell_type": "markdown", "metadata": { "id": "5nYhiWesVMjK" }, "source": [ "Type in `resnet` to the search bar in the top-right.\n", "\n", "This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n", "\n", "It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n", "\n", "You can click the arrows next to that tile to partially collapse these blocks.\n", "\n", "Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n", "It should be at roughly the same height.\n", "\n", "Clear the search bar so that the trace is in color.\n", "Zoom in on the area of the forwards pass\n", "using the \"zoom\" tool in the floating toolbar,\n", "so you can see more detail.\n", "The zoom tool is indicated by a two-headed arrow\n", "pointing into and out of the screen.\n", "\n", "Switch to the \"drag\" tool,\n", "represented by a four-headed arrow.\n", "Click-and-hold to use this tool to focus\n", "on different parts of the timeline\n", "and click on the individual colored boxes\n", "to see details about a particular method call.\n", "\n", "As we go down in the icicle chart,\n", "we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n", "to much more precise `cudnn` and `cuda` operations\n", "(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n", "\n", "`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n", "is the tensor math library in PyTorch\n", "that links to specific backends like `cudnn`." 
] }, { "cell_type": "markdown", "metadata": { "id": "Fq181ybIvLsH" }, "source": [ "### GPU kernel execution" ] }, { "cell_type": "markdown", "metadata": { "id": "IbkWp5aKvLsH" }, "source": [ "Towards the bottom, you should see a section labeled \"GPU\".\n", "The label appears on the far left.\n", "\n", "Within it, you'll see one or more \"`stream`s\".\n", "These are units of work on a GPU,\n", "akin loosely to threads on the CPU.\n", "\n", "When there are colored bars in this area,\n", "the GPU is doing work of some kind.\n", "The fraction of this bar that is filled in with color\n", "is the same as the \"GPU Utilization %\" we've seen previously.\n", "So the first thing to visually assess\n", "in a trace view of PyTorch code\n", "is what fraction of this area is filled with color.\n", "\n", "In CUDA, work is queued up to be\n", "placed into streams and completed, on the GPU,\n", "in a distributed and asynchronous manner.\n", "\n", "The selection of which work to do\n", "is happening on the CPU,\n", "and that's what we were looking at above.\n", "\n", "The CPU and the GPU have to work together to coordinate\n", "this work.\n", "\n", "Type `cuda` into the search bar and you'll see these coordination operations happening:\n", "`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n", "\n", "Running the same PyTorch model\n", "with the same high level operations like `Conv2d` in different versions of PyTorch,\n", "on different GPUs, and even on tensors of different sizes will result\n", "in different choices of concrete kernel operation,\n", "e.g. different matrix multiplication algorithms.\n", "\n", "Type `sync` into the search bar and you'll see places where either work on the GPU\n", "or work on the CPU needs to await synchronization,\n", "e.g. 
copying data from the CPU to the GPU\n", "or the CPU waiting to decide what to do next\n", "on the basis of the contents of a tensor.\n", "\n", "If you see a \"sync\" block above an area\n", "where the stream on the GPU is empty,\n", "you've got a performance bottleneck due to synchronization\n", "between the CPU and GPU.\n", "\n", "To resolve the bottleneck,\n", "head up the icicle chart until you reach the recognizable\n", "PyTorch modules and operations.\n", "Find where they are called in your PyTorch module.\n", "That's a good place to review your code to understand why the synchronization is happening\n", "and to remove it if it's not necessary." ] }, { "cell_type": "markdown", "metadata": { "id": "XeMPbu_jvLsI" }, "source": [ "### The backwards pass\n", "\n", "Type `backward` into the search bar.\n", "\n", "This will highlight components of our backwards pass.\n", "\n", "If you read it from left to right,\n", "you'll see that it begins by calculating the loss\n", "(search for `NllLoss2DBackward` if you can't find it)\n", "and ends by doing a `ConvolutionBackward`,\n", "the first layer of the ResNet.\n", "It is, indeed, backwards.\n", "\n", "Like the forwards pass,\n", "the backwards pass also involves the CPU\n", "telling the GPU which kernels to run.\n", "It's typically run in a separate\n", "thread from the forwards pass,\n", "so you'll see it separated out from the forwards pass\n", "in the trace viewer.\n", "\n", "Generally, there's no need to specifically optimize the backwards pass --\n", "removing bottlenecks in the forwards pass results in a fast backwards pass.\n", "\n", "One reason why is that these two passes are just\n", "\"transposes\" of one another,\n", "so they share a lot of properties,\n", "and bottlenecks in one become bottlenecks in the other.\n", "We can choose to optimize either one of the two.\n", "But the forwards pass is under our direct control,\n", "so it's easier for us to reason about.\n", "\n", "Another reason is that 
the forwards pass is more likely to have bottlenecks.\n", "The forwards pass is a dynamic process,\n", "with each line of Python adding more to the compute graph.\n", "Backwards passes, on the other hand, use a static compute graph,\n", "the one just defined by the forwards pass,\n", "so more optimizations are possible." ] }, { "cell_type": "markdown", "metadata": { "id": "gWiDw0vCvLsI" }, "source": [ "### The optimizer step" ] }, { "cell_type": "markdown", "metadata": { "id": "ndfkzEdnvLsI" }, "source": [ "Type `Adam.step` into the search bar to highlight the computations of the optimizer.\n", "\n", "As with the two passes,\n", "we are still using the CPU\n", "to launch kernels on the GPU.\n", "But now the CPU is looping,\n", "in Python, over the parameters\n", "and applying the Adam update rules to each.\n", "\n", "We now know enough to see that\n", "this is not great for our GPU utilization:\n", "there are many areas of gray\n", "in between the colored bars\n", "in the GPU stream in this area.\n", "\n", "In the time it takes CUDA to multiply\n", "thousands of numbers,\n", "Python has not yet finished cleaning up\n", "after its request for that multiplication.\n", "\n", "As of writing in August 2022,\n", "more efficient optimizers are not a stable part of PyTorch (v1.12), but\n", "[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n", "and stable implementations outside of PyTorch.\n", "The standard implementations are in\n", "[NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n", "not to be confused with the\n", "[Apex Optimizers Project](https://www.apexoptimizers.com/),\n", "which is a collection of fitness-themed cheetah NFTs." 
] }, { "cell_type": "markdown", "metadata": { "id": "WX0jxeafvLsI" }, "source": [ "## Take-aways for PyTorch performance bottleneck troubleshooting" ] }, { "cell_type": "markdown", "metadata": { "id": "CugD-bK2vLsI" }, "source": [ "Our goal here was to learn some basic principles and tools for troubleshooting\n", "the most common issues and the lowest-hanging fruit in PyTorch code." ] }, { "cell_type": "markdown", "metadata": { "id": "SwHwJkVMHYGA" }, "source": [ "\n", "Here's an overview in terms of a \"host\",\n", "generally the CPU,\n", "and a \"device\", here the GPU.\n", "\n", "- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n", "- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n", " - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n", "- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n", "stores information about the contents of tensors,\n", "not just their metadata." ] }, { "cell_type": "markdown", "metadata": { "id": "Gntx28p9cBP5" }, "source": [ "Towards that goal, we viewed the trace to get an understanding of\n", "what's going on inside a PyTorch training step." 
] }, { "cell_type": "markdown", "metadata": { "id": "AKvZGPnkeXvq" }, "source": [ "Here's what that means in terms of troubleshooting bottlenecks.\n", "\n", "We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n", "before the previous kernel finishes.\n", "\n", "Ideally, the CPU is actually getting far _ahead_ of execution\n", "on the GPU.\n", "If the CPU makes it all the way through the backwards pass before the GPU is done,\n", "that's great!\n", "The GPU(s) are the expensive part,\n", "and it's easy to use multiprocessing so that\n", "the CPU has other things to do.\n", "\n", "This helps explain at least one common piece of advice:\n", "the larger our batches are,\n", "the more work the GPU has to do for the same work done by the CPU,\n", "and so the better our utilization will be." ] }, { "cell_type": "markdown", "metadata": { "id": "XMztpa-TccH4" }, "source": [ "We operationalize our desire to never be waiting on the CPU with a simple metric:\n", "**100% GPU utilization**, meaning a kernel is running at all times.\n", "\n", "This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n", "\n", "You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will more than double the cost." ] }, { "cell_type": "markdown", "metadata": { "id": "7kYBygfScR6z" }, "source": [ "Here are some of the most common issues that lead to low GPU Utilization, and how to resolve them:\n", "1. **The CPU is too weak**.\n", "Because so much of the discussion around DNN performance is about GPUs,\n", "it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n", "_Resolution_:\n", "Use nice CPUs, like\n", "[threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n", "2. 
**Too much Python during the `training_step`**.\n", "Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on an `__init__`\n", "that takes longer than running an entire layer.\n", "_Resolution_:\n", "Look for low-utilization areas of the trace,\n", "check what's happening on the CPU at that time,\n", "and carefully review the Python code being executed.\n", "3. **Unnecessary Host/Device synchronization**.\n", "If one of your operations depends on the values in a tensor,\n", "like `if xs.mean() >= 0`,\n", "you'll induce a synchronization between\n", "the host and the device and possibly lead\n", "to an expensive and slow copy of data.\n", "_Resolution_:\n", "Replace these operations as much as possible\n", "with purely array-based calculations.\n", "4. **Bottlenecking on the DataLoader**.\n", "In addition to coordinating the work on the GPU,\n", "CPUs often perform heavy data operations,\n", "including communication over the network\n", "and writing to/reading from disk.\n", "These are generally done in parallel to the forwards\n", "and backwards passes,\n", "but if they don't finish before that happens,\n", "they will become the bottleneck.\n", "_Resolution_:\n", "Get better hardware for compute,\n", "memory, and network.\n", "For software solutions, the answer \n", "is a bit more complex and application-dependent.\n", "For generic tips, see\n", "[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n", "in the PyTorch forums.\n", "For techniques in computer vision, see\n", "[the FFCV library](https://github.com/libffcv/ffcv)\n", "and for techniques in NLP, see e.g.\n", "[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n", "and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)." 
] }, { "cell_type": "markdown", "metadata": { "id": "i2WYS8bQvLsJ" }, "source": [ "### Further steps in making DNNs go brrrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "T0wW2_lRKfY1" }, "source": [ "It's important to note that utilization\n", "is just an easily measured metric\n", "that can reveal common bottlenecks.\n", "Having high utilization does not automatically mean\n", "that your performance is fully optimized.\n", "\n", "For example,\n", "synchronization events between GPUs\n", "are counted as kernels,\n", "so a deadlock during distributed training\n", "can show up as 100% utilization,\n", "despite literally no useful work occurring.\n", "\n", "Just switching to \n", "double precision floats, `--precision=64`,\n", "will generally lead to much higher utilization.\n", "The GPU operations take longer\n", "for roughly the same amount of CPU effort,\n", "but the added precision brings no benefit.\n", "\n", "In particular, it doesn't make for models\n", "that perform better on our correctness metrics,\n", "like loss and accuracy.\n", "\n", "Another useful yardstick to add\n", "to utilization is examples per second,\n", "which incorporates how quickly the model is processing data examples\n", "and calculating gradients.\n", "\n", "But really,\n", "the gold star is _decrease in loss per second_.\n", "This metric connects model design choices\n", "and hyperparameters with purely engineering concerns,\n", "so it disrespects abstraction barriers\n", "and doesn't generally lead to actionable recommendations,\n", "but it is, in the end, the real goal:\n", "make the loss go down faster so we get better models sooner." ] }, { "cell_type": "markdown", "metadata": { "id": "EFzPsplfdo_o" }, "source": [ "For PyTorch internals abstractly,\n", "see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "For more on performance considerations in PyTorch,\n", "see [Horace He's blog post](https://horace.io/brrr_intro.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "yq6-S6TC38AY" }, "source": [ "### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n", "\n", "One of the most important features for making\n", "PyTorch run quickly is the\n", "`MultiprocessingDataLoader`,\n", "which executes batching of data in a separate process\n", "from the forwards and backwards passes.\n", "\n", "By default in PyTorch,\n", "this feature is actually turned off,\n", "via the `DataLoader` argument `num_workers`\n", "having a default value of `0`,\n", "but we set the `DEFAULT_NUM_WORKERS`\n", "to a value based on the number of CPUs\n", "available on the system running the code.\n", "\n", "Re-run the profiling cell,\n", "but set `num_workers` to `0`\n", "to turn off multiprocessing.\n", "\n", "Compare and contrast the two traces,\n", "both for total runtime\n", "(see the time axis at the top of the trace)\n", "and for utilization.\n", "\n", "If you're unable to run the profiles,\n", "see the results\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n", "which juxtaposes two traces,\n", "with in-process dataloading on the left and\n", "multiprocessing dataloading on the right." ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test." 
] }, { "cell_type": "markdown", "metadata": { "id": "T2i_a5eVeIoA" }, "source": [ "The file below incorrectly implements and then incorrectly tests\n", "a simple PyTorch utility for adding five to every entry of a tensor\n", "and then calculating the sum.\n", "\n", "Even worse, it does it with horrible style!\n", "\n", "The cells below apply our linting checks\n", "(after automatically fixing the formatting)\n", "and run the test.\n", "\n", "Fix all of the lints,\n", "implement the function correctly,\n", "and then implement some basic tests." ] }, { "cell_type": "markdown", "metadata": { "id": "wSon2fB5VVM_" }, "source": [ "- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aYiRvU4HA84t" }, "outputs": [], "source": [ "%%writefile training/fixme.py\n", "import torch\n", "from training import run_experiment\n", "from numpy import *\n", "import random\n", "from pathlib import Path\n", "\n", "\n", "\n", "\n", "def add_five_and_sum(tensor):\n", " # this function is not implemented right,\n", " # but it's supposed to add five to all tensor entries and sum them up\n", " return 1\n", "\n", "def test_add_five_and_sum():\n", " # and this test isn't right either! 
plus this isn't exactly a docstring\n", " all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n", " all_fives = 5 * all_ones\n", " assert False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EXJpmvuzT1w0" }, "outputs": [], "source": [ "!pre-commit run black --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SRO-oJfdUrcQ" }, "outputs": [], "source": [ "!cat training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jM8NHxVbSEQD" }, "outputs": [], "source": [ "!pre-commit run --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kj0VMBSndtkc" }, "outputs": [], "source": [ "!pytest training/fixme.py" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab05_troubleshooting.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab05/tasks/lint.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false # apply automatic formatting echo "black" pre-commit run black || FAILURE=true # check for python code style violations, see .flake8 for details echo "flake8" pre-commit run flake8 || FAILURE=true # check for shell scripting style violations and common bugs echo "shellcheck" pre-commit run shellcheck || FAILURE=true # check python types echo "mypy" pre-commit run mypy || FAILURE=true if [ "$FAILURE" = true ]; then echo "Linting failed" exit 1 fi echo "Linting passed" exit 0 
================================================ FILE: lab05/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab05/text_recognizer/callbacks/__init__.py ================================================ from .model import ModelSizeLogger from .optim import LearningRateMonitor from . import imtotext from .imtotext import ImageToTextTableLogger as ImageToTextLogger ================================================ FILE: lab05/text_recognizer/callbacks/imtotext.py ================================================ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only try: import wandb has_wandb = True except ImportError: has_wandb = False from .util import check_and_warn class ImageToTextTableLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.on_train: if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "validation/predictions") def _log_image_text_table(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = xs[:mx], gt_strs[:mx], pred_strs[:mx] 
xs = [wandb.Image(x) for x in xs] rows = zip(*[xs, gt_strs, pred_strs]) columns = ["input_image", "ground_truth_string", "predicted_string"] trainer.logger.log_table(key=key, columns=columns, data=list(rows)) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) class ImageToTextCaptionLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "validation/predictions") @rank_zero_only def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "test/predictions") def _log_image_text_caption(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], pred_strs[:mx] trainer.logger.log_image(key, xs, caption=pred_strs) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) ================================================ FILE: lab05/text_recognizer/callbacks/model.py 
================================================ import os from pathlib import Path import tempfile import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_only import torch from .util import check_and_warn, logging try: import torchviz has_torchviz = True except ImportError: has_torchviz = False class ModelSizeLogger(pl.Callback): """Logs information about model size (in parameters and on disk).""" def __init__(self, print_size=True): super().__init__() self.print_size = print_size @rank_zero_only def on_fit_start(self, trainer, module): self._run(trainer, module) def _run(self, trainer, module): metrics = {} metrics["mb_disk"] = self.get_model_disksize(module) metrics["nparams"] = count_params(module) if self.print_size: print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB") metrics = {f"size/{key}": value for key, value in metrics.items()} trainer.logger.log_metrics(metrics, step=-1) @staticmethod def get_model_disksize(module): """Determine the model's size on disk by saving it to disk.""" with tempfile.NamedTemporaryFile() as f: torch.save(module.state_dict(), f) size_mb = os.path.getsize(f.name) / 1e6 return size_mb class GraphLogger(pl.Callback): """Logs a compute graph as an image.""" def __init__(self, output_key="logits"): super().__init__() self.graph_logged = False self.output_key = output_key if not has_torchviz: raise ImportError("GraphLogCallback requires torchviz." 
"") @rank_zero_only def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx): if not self.graph_logged: try: outputs = outputs[0][0]["extra"] self.log_graph(trainer, module, outputs[self.output_key]) except KeyError: logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}") self.graph_logged = True @staticmethod def log_graph(trainer, module, outputs): if check_and_warn(trainer.logger, "log_image", "graph"): return params_dict = dict(list(module.named_parameters())) graph = torchviz.make_dot(outputs, params=params_dict) graph.format = "png" fname = Path(trainer.logger.experiment.dir) / "graph" graph.render(fname) fname = str(fname.with_suffix("." + graph.format)) trainer.logger.log_image(key="graph", images=[fname]) def count_params(module): """Counts the number of parameters in a Torch Module.""" return sum(p.numel() for p in module.parameters()) ================================================ FILE: lab05/text_recognizer/callbacks/optim.py ================================================ import pytorch_lightning as pl KEY = "optimizer" class LearningRateMonitor(pl.callbacks.LearningRateMonitor): """Extends Lightning's LearningRateMonitor with a prefix. Logs the learning rate during training. See the docs for pl.callbacks.LearningRateMonitor for details. 
""" def _add_prefix(self, *args, **kwargs) -> str: return f"{KEY}/" + super()._add_prefix(*args, **kwargs) ================================================ FILE: lab05/text_recognizer/callbacks/util.py ================================================ import logging logging.basicConfig(level=logging.WARNING) def check_and_warn(logger, attribute, feature): if not hasattr(logger, attribute): warn_no_attribute(feature, attribute) return True def warn_no_attribute(blocked_feature, missing_attribute): logging.warning(f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}.") ================================================ FILE: lab05/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. 
""" from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines from .iam_paragraphs import IAMParagraphs from .iam_lines import IAMLines from .fake_images import FakeImageData ================================================ FILE: lab05/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed data parallel mode, we launch a training on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // 
NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS class BaseDataModule(pl.LightningDataModule): """Base for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`). """ def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. 
This is done once per GPU in a DDP setting. Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab05/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. 
"The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = 
toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") 
shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab05/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab05/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims =
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines) ================================================ FILE: lab05/text_recognizer/data/fake_images.py ================================================ """A fake image dataset for testing.""" import argparse import torch import torchvision from text_recognizer.data.base_data_module import BaseDataModule _NUM_SAMPLES = 512 _IMAGE_LEN = 28 _NUM_CLASSES = 10 class FakeImageData(BaseDataModule): """Fake images dataset.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.num_samples = self.args.get("num_samples", _NUM_SAMPLES) self.input_dims = (1, self.args.get("image_height", _IMAGE_LEN), self.args.get("image_width", _IMAGE_LEN)) self.num_classes = self.args.get("num_classes", _NUM_CLASSES) self.output_dims = (self.num_classes, 1) self.mapping = list(range(0, self.num_classes)) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--num_samples", type=int, default=_NUM_SAMPLES) parser.add_argument("--num_classes", type=int, default=_NUM_CLASSES) parser.add_argument("--image_height", type=int, default=_IMAGE_LEN) parser.add_argument("--image_width", type=int, default=_IMAGE_LEN) return parser def setup(self, stage: str = None) -> None: fake_dataset = torchvision.datasets.FakeData( size=self.num_samples, image_size=self.input_dims, num_classes=self.output_dims[0], transform=torchvision.transforms.ToTensor(), ) val_size = int(self.num_samples * 0.25) self.data_train, self.data_val, self.data_test = torch.utils.data.random_split( # type: ignore dataset=fake_dataset, lengths=[self.num_samples - 2 * val_size, val_size, val_size] ) ================================================ FILE: lab05/text_recognizer/data/iam.py
================================================ """Class for loading the IAM handwritten text dataset, which encompasses both paragraphs and lines, plus utilities.""" from pathlib import Path from typing import Any, cast, Dict, List, Optional import zipfile from boltons.cacheutils import cachedproperty from defusedxml import ElementTree from PIL import Image, ImageOps import toml from text_recognizer import util from text_recognizer.data.base_data_module import _download_raw_dataset, load_and_print_info import text_recognizer.metadata.iam as metadata from text_recognizer.metadata.iam_paragraphs import NEW_LINE_TOKEN METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME EXTRACTED_DATASET_DIRNAME = metadata.EXTRACTED_DATASET_DIRNAME class IAM: """A dataset of images of handwritten text written on a form underneath a typewritten prompt. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database Images are identified by their "form ID". These IDs are used to separate train, validation and test splits, as keys for dictionaries returning label and image crop region data, and more. The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (LWITLRT): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.
""" def __init__(self): self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self): if self.xml_filenames: return filename = _download_raw_dataset(self.metadata, DL_DATA_DIRNAME) # type: ignore _extract_raw_dataset(filename, DL_DATA_DIRNAME) def load_image(self, id: str) -> Image.Image: """Load and return an image of an entire IAM form. The image is grayscale with white text on black background. This image will have the printed prompt text at the top, above the handwritten text. Images of individual words or lines and of whole paragraphs can be cropped out using the relevant crop region data. """ image = util.read_image_pil(self.form_filenames_by_id[id], grayscale=True) image = ImageOps.invert(image) return image def __repr__(self): """Print info about the dataset.""" info = ["IAM Dataset"] info.append(f"Total Images: {len(self.xml_filenames)}") info.append(f"Total Test Images: {len(self.test_ids)}") info.append(f"Total Paragraphs: {len(self.paragraph_string_by_id)}") num_lines = sum(len(line_regions) for line_regions in self.line_regions_by_id.values()) info.append(f"Total Lines: {num_lines}") return "\n\t".join(info) @cachedproperty def all_ids(self): """A list of all form IDs.""" return sorted([f.stem for f in self.xml_filenames]) @cachedproperty def ids_by_split(self): return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids} @cachedproperty def split_by_id(self): """A dictionary mapping form IDs to their split according to IAM Lines LWITLRT.""" split_by_id = {id_: "train" for id_ in self.train_ids} split_by_id.update({id_: "val" for id_ in self.validation_ids}) split_by_id.update({id_: "test" for id_ in self.test_ids}) return split_by_id @cachedproperty def train_ids(self): """A list of form IDs which are in the IAM Lines LWITLRT training set.""" return list(set(self.all_ids) - (set(self.test_ids) | set(self.validation_ids))) @cachedproperty def test_ids(self): """A list of form IDs from the IAM Lines LWITLRT test set."""
return _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/testset.txt") @property def xml_filenames(self) -> List[Path]: """A list of the filenames of all .xml files, which contain label information.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @cachedproperty def validation_ids(self): """A list of form IDs from IAM Lines LWITLRT validation sets 1 and 2.""" val_ids = _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset1.txt") val_ids.extend(_get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset2.txt")) return val_ids @property def form_filenames(self) -> List[Path]: """A list of the filenames of all .jpg files, which contain images of IAM forms.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def xml_filenames_by_id(self): """A dictionary mapping form IDs to their XML label information files.""" return {filename.stem: filename for filename in self.xml_filenames} @property def form_filenames_by_id(self): """A dictionary mapping form IDs to their JPEG images.""" return {filename.stem: filename for filename in self.form_filenames} @cachedproperty def line_strings_by_id(self): """A dict mapping an IAM form id to its list of line texts.""" return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def line_regions_by_id(self): """A dict mapping an IAM form id to its list of line image crop regions.""" return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def paragraph_string_by_id(self): """A dict mapping an IAM form id to its paragraph text.""" return {id: NEW_LINE_TOKEN.join(line_strings) for id, line_strings in self.line_strings_by_id.items()} @cachedproperty def paragraph_region_by_id(self): """A dict mapping an IAM form id to its paragraph image crop region.""" return { id: { "x1": min(region["x1"] for region in line_regions),
"y1": min(region["y1"] for region in line_regions), "x2": max(region["x2"] for region in line_regions), "y2": max(region["y2"] for region in line_regions), } for id, line_regions in self.line_regions_by_id.items() } def _extract_raw_dataset(filename: Path, dirname: Path) -> None: print("Extracting IAM data") with util.temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zip_file: zip_file.extractall() def _get_ids_from_lwitlrt_split_file(filename: str) -> List[str]: """Get the ids from Large Writer Independent Text Line Recognition Task (LWITLRT) data split file.""" with open(filename, "r") as f: line_ids_str = f.read() line_ids = line_ids_str.split("\n") page_ids = list({"-".join(line_id.split("-")[:2]) for line_id in line_ids if line_id}) return page_ids def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace &quot; with ".""" xml_line_elements = _get_line_elements_from_xml_file(filename) return [_get_text_from_xml_element(el) for el in xml_line_elements] def _get_text_from_xml_element(xml_element: Any) -> str: """Extract text from any XML element.""" return xml_element.attrib["text"].replace("&quot;", '"') def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get the line region dict for each line.""" xml_line_elements = _get_line_elements_from_xml_file(filename) line_regions = [ cast(Dict[str, int], _get_region_from_xml_element(xml_elem=el, xml_path="word/cmp")) for el in xml_line_elements ] assert all(region is not None for region in line_regions), "Line regions cannot be None" # next_line_region["y1"] - prev_line_region["y2"] can be negative due to overlapping characters line_gaps_y = [ max(next_line_region["y1"] - prev_line_region["y2"], 0) for next_line_region, prev_line_region in zip(line_regions[1:], line_regions[:-1]) ] post_line_gaps_y = line_gaps_y + [2 * metadata.LINE_REGION_PADDING] pre_line_gaps_y = [2 * metadata.LINE_REGION_PADDING] +
line_gaps_y return [ { "x1": region["x1"] - metadata.LINE_REGION_PADDING, "x2": region["x2"] + metadata.LINE_REGION_PADDING, "y1": region["y1"] - min(metadata.LINE_REGION_PADDING, pre_line_gaps_y[i] // 2), "y2": region["y2"] + min(metadata.LINE_REGION_PADDING, post_line_gaps_y[i] // 2), } for i, region in enumerate(line_regions) ] def _get_line_elements_from_xml_file(filename: str) -> List[Any]: """Get all line xml elements from xml file.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec return xml_root_element.findall("handwritten-part/line") def _get_region_from_xml_element(xml_elem: Any, xml_path: str) -> Optional[Dict[str, int]]: """ Get region from input xml element. The region is downsampled because the stored images are also downsampled. Parameters ---------- xml_elem xml element can be a line or word element with x, y, width, and height attributes xml_path should be "word/cmp" if xml_elem is a line element, else "cmp" """ unit_elements = xml_elem.findall(xml_path) if not unit_elements: return None return { "x1": min(int(el.attrib["x"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y1": min(int(el.attrib["y"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "x2": max(int(el.attrib["x"]) + int(el.attrib["width"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y2": max(int(el.attrib["y"]) + int(el.attrib["height"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, } if __name__ == "__main__": load_and_print_info(IAM) ================================================ FILE: lab05/text_recognizer/data/iam_lines.py ================================================ """A dataset of lines of handwritten text derived from the IAM dataset.""" import argparse import json from pathlib import Path from typing import Sequence import numpy as np from PIL import Image, ImageFile from text_recognizer import util from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam 
import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.line import IAMLineStem ImageFile.LOAD_TRUNCATED_IMAGES = True PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR class IAMLines(BaseDataModule): """Lines of text pulled from the IAM Handwriting database.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true") == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = IAMLineStem() self.trainval_transform = IAMLineStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if PROCESSED_DATA_DIRNAME.exists(): return print("Cropping IAM line regions...") iam = IAM() iam.prepare_data() crops_train, labels_train = generate_line_crops_and_labels(iam, "train") crops_val, labels_val = generate_line_crops_and_labels(iam, "val") crops_test, labels_test = generate_line_crops_and_labels(iam, "test") shapes = np.array([crop.size for crop in crops_train + crops_val + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] print("Saving images, labels, and statistics...") save_images_and_labels(crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_val, labels_val, "val", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt", "w") as file: file.write(str(aspect_ratios.max())) def 
setup(self, stage: str = None) -> None: with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt") as file: max_aspect_ratio = float(file.read()) image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) assert image_width <= metadata.IMAGE_WIDTH if stage == "fit" or stage is None: x_train, labels_train = load_processed_crops_and_labels("train", PROCESSED_DATA_DIRNAME) y_train = convert_strings_to_labels(labels_train, self.inverse_mapping, length=self.output_dims[0]) self.data_train = BaseDataset(x_train, y_train, transform=self.trainval_transform) x_val, labels_val = load_processed_crops_and_labels("val", PROCESSED_DATA_DIRNAME) y_val = convert_strings_to_labels(labels_val, self.inverse_mapping, length=self.output_dims[0]) self.data_val = BaseDataset(x_val, y_val, transform=self.trainval_transform) # quick check: do we have the right sequence lengths? assert self.output_dims[0] >= max([len(_) for _ in labels_train]) + 2 # Add 2 for start/end tokens. assert self.output_dims[0] >= max([len(_) for _ in labels_val]) + 2 # Add 2 for start/end tokens. 
if stage == "test" or stage is None: x_test, labels_test = load_processed_crops_and_labels("test", PROCESSED_DATA_DIRNAME) y_test = convert_strings_to_labels(labels_test, self.inverse_mapping, length=self.output_dims[0]) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) assert self.output_dims[0] >= max([len(_) for _ in labels_test]) + 2 def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Lines Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def generate_line_crops_and_labels(iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR): """Create both cropped lines and associated labels from IAM, with resizing by default""" crops, labels = [], [] for iam_id in iam.ids_by_split[split]: labels += iam.line_strings_by_id[iam_id] image = iam.load_image(iam_id) for line in iam.line_regions_by_id[iam_id]: coords = [line[point] for point in ["x1", "y1", "x2", "y2"]] crop = image.crop(coords) crop = resize_image(crop, scale_factor=scale_factor) crops.append(crop) assert len(crops) == len(labels) return crops, labels def save_images_and_labels(crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path): (data_dirname / split).mkdir(parents=True, exist_ok=True) with open(data_dirname / split / "_labels.json", "w") as f: json.dump(labels, f) 
for ind, crop in enumerate(crops): crop.save(data_dirname / split / f"{ind}.png") def load_processed_crops_and_labels(split: str, data_dirname: Path): """Load line crops and labels for given split from processed directory.""" crops = load_processed_line_crops(split, data_dirname) labels = load_processed_line_labels(split, data_dirname) assert len(crops) == len(labels) return crops, labels def load_processed_line_crops(split: str, data_dirname: Path): """Load line crops for given split from processed directory.""" crop_filenames = sorted((data_dirname / split).glob("*.png"), key=lambda filename: int(Path(filename).stem)) crops = [util.read_image_pil(filename, grayscale=True) for filename in crop_filenames] return crops def load_processed_line_labels(split: str, data_dirname: Path): """Load line labels for given split from processed directory.""" with open(data_dirname / split / "_labels.json") as file: labels = json.load(file) return labels if __name__ == "__main__": load_and_print_info(IAMLines) ================================================ FILE: lab05/text_recognizer/data/iam_paragraphs.py ================================================ """IAM Paragraphs Dataset class.""" import argparse import json from pathlib import Path from typing import Callable, Dict, Optional, Sequence, Tuple import numpy as np from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.paragraph import ParagraphStem IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME class IAMParagraphs(BaseDataModule): """IAM Handwriting 
database paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true").lower() == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = ParagraphStem() self.trainval_transform = ParagraphStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): return rank_zero_info( "IAMParagraphs.prepare_data: Cropping IAM paragraph regions and saving them along with labels..." ) iam = IAM() iam.prepare_data() properties = {} for split in ["train", "val", "test"]: crops, labels = get_paragraph_crops_and_labels(iam=iam, split=split) save_crops_and_labels(crops=crops, labels=labels, split=split) properties.update( { id_: { "crop_shape": crops[id_].size[::-1], "label_length": len(label), "num_lines": _num_lines(label), } for id_, label in labels.items() } ) with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: def _load_dataset(split: str, transform: Callable) -> BaseDataset: crops, labels = load_processed_crops_and_labels(split) Y = convert_strings_to_labels(strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]) return BaseDataset(crops, Y, transform=transform) rank_zero_info(f"IAMParagraphs.setup({stage}): Loading IAM paragraph regions and lines...") validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) if stage == "fit" or stage is None: self.data_train = _load_dataset(split="train", 
transform=self.trainval_transform) self.data_val = _load_dataset(split="val", transform=self.transform) if stage == "test" or stage is None: self.data_test = _load_dataset(split="test", transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def validate_input_and_output_dimensions( input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] ) -> None: """Validate input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] # Add 2 because of start and end tokens assert output_dims is not None and output_dims[0] >= properties["label_length"]["max"] + 2 def get_paragraph_crops_and_labels( iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR ) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: """Create IAM paragraph crops and labels for a given split, with resizing.""" crops = {} labels = {} for iam_id in iam.ids_by_split[split]: image = iam.load_image(iam_id) para_region = iam.paragraph_region_by_id[iam_id] crops[iam_id] = 
image.crop([para_region[_] for _ in ["x1", "y1", "x2", "y2"]]) crops[iam_id] = resize_image(crops[iam_id], scale_factor=scale_factor) labels[iam_id] = iam.paragraph_string_by_id[iam_id] assert len(crops) == len(labels) return crops, labels def save_crops_and_labels(crops: Dict[str, Image.Image], labels: Dict[str, str], split: str): """Save crops, labels and shapes of crops of a split.""" (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with open(_labels_filename(split), "w") as f: json.dump(labels, f, indent=4) for id_, crop in crops.items(): crop.save(_crop_filename(id_, split)) def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: """Load processed crops and labels for given split.""" with open(_labels_filename(split), "r") as f: labels = json.load(f) sorted_ids = sorted(labels.keys()) ordered_crops = [Image.open(_crop_filename(id_, split)).convert("L") for id_ in sorted_ids] ordered_labels = [labels[id_] for id_ in sorted_ids] assert len(ordered_crops) == len(ordered_labels) return ordered_crops, ordered_labels def get_dataset_properties() -> dict: """Return properties describing the overall dataset.""" with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: properties = json.load(f) def _get_property_values(key: str) -> list: return [_[key] for _ in properties.values()] crop_shapes = np.array(_get_property_values("crop_shape")) aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0] return { "label_length": { "min": min(_get_property_values("label_length")), "max": max(_get_property_values("label_length")), }, "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))}, "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)}, "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, } def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / 
"_labels.json" def _crop_filename(id_: str, split: str) -> Path: """Return filename of processed crop.""" return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" def _num_lines(label: str) -> int: """Return number of lines of text in label.""" return label.count(NEW_LINE_TOKEN) + 1 if __name__ == "__main__": load_and_print_info(IAMParagraphs) ================================================ FILE: lab05/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab05/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import 
string from typing import List, Optional import nltk import numpy as np from text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def __init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in 
string.punctuation}) text = re.sub(" +", " ", text) return text def load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab05/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. 
        Parameters
        ----------
        index

        Returns
        -------
        (datum, target)
        """
        datum, target = self.data[index], self.targets[index]

        if self.transform is not None:
            datum = self.transform(datum)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return datum, target


def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor:
    """
    Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens,
    and padded with the <P> token.
    """
    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
    for i, string in enumerate(strings):
        tokens = list(string)
        tokens = ["<S>", *tokens, "<E>"]
        for ii, token in enumerate(tokens):
            labels[i, ii] = mapping[token]
    return labels


def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]:
    """
    Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the
    other of size (1 - fraction) * size of the base_dataset.
    """
    split_a_size = int(fraction * len(base_dataset))
    split_b_size = len(base_dataset) - split_a_size
    return torch.utils.data.random_split(  # type: ignore
        base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed)
    )


def resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
    """Resize image by scale factor."""
    if scale_factor == 1:
        return image
    return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR)


================================================
FILE: lab05/text_recognizer/lit_models/__init__.py
================================================
from .base import BaseLitModel

from .transformer import TransformerLitModel


================================================
FILE: lab05/text_recognizer/lit_models/base.py
================================================
"""Basic LightningModules on which other modules can be built."""
import argparse

import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy

from .metrics import CharacterErrorRate

OPTIMIZER = "Adam"
LR = 1e-3
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100


class BaseLitModel(pl.LightningModule):
    """
    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
    """

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__()
        self.model = model
        self.args = vars(args) if args is not None else {}

        self.data_config = self.model.data_config
        self.mapping = self.data_config["mapping"]
        self.input_dims = self.data_config["input_dims"]

        optimizer = self.args.get("optimizer", OPTIMIZER)
        self.optimizer_class = getattr(torch.optim, optimizer)

        self.lr = self.args.get("lr", LR)

        loss = self.args.get("loss", LOSS)
        if loss not in ("transformer",):
            self.loss_fn = getattr(torch.nn.functional, loss)

        self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None)
        self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    @staticmethod
    def add_to_argparse(parser):
        parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim")
        parser.add_argument("--lr", type=float, default=LR)
        parser.add_argument("--one_cycle_max_lr", type=float, default=None)
        parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS)
        parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional")
        return parser

    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        if self.one_cycle_max_lr is None:
            return optimizer
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"}

    def forward(self, x):
        return self.model(x)

    def predict(self, x):
        logits = self.model(x)
        return torch.argmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        x, y, logits, loss = self._run_on_batch(batch)
        self.train_acc(logits, y)

        self.log("train/loss", loss)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)

        outputs = {"loss": loss}
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def _run_on_batch(self, batch, with_preds=False):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        return x, y, logits, loss

    def validation_step(self, batch, batch_idx):
        x, y, logits, loss = self._run_on_batch(batch)
        self.val_acc(logits, y)

        self.log("validation/loss", loss, prog_bar=True, sync_dist=True)
        self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

        outputs = {"loss": loss}
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def test_step(self, batch, batch_idx):
        x, y, logits, loss = self._run_on_batch(batch)
        self.test_acc(logits, y)

        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)

    def add_on_first_batch(self, metrics, outputs, batch_idx):
        if batch_idx == 0:
            outputs.update(metrics)

    def add_on_logged_batches(self, metrics, outputs):
        # is_logged_batch is a method; it must be called, since a bare bound method is always truthy
        if self.is_logged_batch():
            outputs.update(metrics)

    def is_logged_batch(self):
        if self.trainer is None:
            return False
        else:
            return self.trainer._logger_connector.should_update_logs


class BaseImageToTextLitModel(BaseLitModel):  # pylint: disable=too-many-ancestors
    """Base class for ImageToText models in PyTorch Lightning."""

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__(model, args)
        self.model = model
        self.args = vars(args) if args is not None else {}

        self.inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)}
        self.start_index = self.inverse_mapping["<S>"]
        self.end_index = self.inverse_mapping["<E>"]
        self.padding_index = self.inverse_mapping["<P>"]

        self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
        self.val_cer = CharacterErrorRate(self.ignore_tokens)
        self.test_cer = CharacterErrorRate(self.ignore_tokens)


================================================
FILE: lab05/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence

import torch
import torchmetrics


class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored."""

    def __init__(self, ignore_tokens: Sequence[int], *args):
        super().__init__(*args)
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)


def test_character_error_rate():
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    assert metric.compute() == sum([0, 0.75, 0.5]) / 3


if __name__ == "__main__":
    test_character_error_rate()


================================================
FILE: lab05/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence

import torch

from .base import BaseImageToTextLitModel
from .util import replace_after


class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.
    The module must implement an encode and decode method, and the forward method
    should be the forward pass during production inference.
    """

    def __init__(self, model, args=None):
        super().__init__(model, args)
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index)

    def forward(self, x):
        return self.model(x)

    def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x.

        Parameters
        ----------
        x
            Batch of images to be encoded. See self.model.encode for shape information.
        y
            Batch of ground truth output sequences.

        Returns
        -------
        torch.Tensor
            (B, C, Sy) logits
        """
        x = self.model.encode(x)
        output = self.model.decode(x, y)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("train/loss", loss)

        outputs = {"loss": loss}
        if self.is_logged_batch():
            preds = self.get_preds(logits)
            pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
            outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs})

        return outputs

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("validation/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)
        self.val_cer(preds, y)

        self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True)

        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def test_step(self, batch, batch_idx):
        x, y = batch
        # compute loss as in training, for comparison
        logits = self.teacher_forward(x, y[:, :-1])
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("test/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)
        # use the dedicated test metric so test results do not share state with validation
        self.test_cer(preds, y)

        self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True)

        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping."""
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        else:
            return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping."""
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.

        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one
            whose index we will return. Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.

        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
        """
        raw = torch.argmax(logitlikes, dim=1)  # (B, C, Sy) -> (B, Sy)
        if replace_after_end:
            return replace_after(raw, self.end_index, self.padding_index)  # (B, Sy)
        else:
            return raw  # (B, Sy)


================================================
FILE: lab05/text_recognizer/lit_models/util.py
================================================
from typing import Union

import torch


def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor:
    """Return indices of first appearance of element in x, collapsing along dim.

    Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9

    Parameters
    ----------
    x
        One or two-dimensional Tensor to search for element.
    element
        Item to search for inside x.
    dim
        Dimension of Tensor to collapse over.

    Returns
    -------
    torch.Tensor
        Indices where element occurs in x. If element is not found,
        return length of x along dim. One dimension smaller than x.

    Raises
    ------
    ValueError
        if x is not a 1 or 2 dimensional Tensor

    Examples
    --------
    >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)
    tensor([2, 1, 3, 0])
    >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0)
    tensor(0)
    """
    if x.dim() > 2 or x.dim() == 0:
        raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}")
    matches = x == element
    first_appearance_mask = (matches.cumsum(dim) == 1) & matches
    does_match, match_index = first_appearance_mask.max(dim)
    first_inds = torch.where(does_match, match_index, x.shape[dim])
    return first_inds


def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor:
    """Replace all values in each row of 2d Tensor x after the first appearance of element with replace.

    Parameters
    ----------
    x
        Two-dimensional Tensor (shape denoted (B, S)) to replace values in.
    element
        Item to search for inside x.
    replace
        Item that replaces entries that appear after element.
    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    first_appearances = first_appearance(x, element, dim=1)  # (B,)
    indices = torch.arange(0, x.shape[-1]).type_as(x)  # (S,)

    outs = torch.where(
        indices[None, :] <= first_appearances[:, None],  # if index is before first appearance
        x,  # return the value from x
        replace,  # otherwise, return the replacement value
    )
    return outs  # (B, S)


================================================
FILE: lab05/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path

import text_recognizer.metadata.shared as shared

RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"

NUM_SPECIAL_TOKENS = 4
INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)

# fmt: off
MAPPING = [
    "<B>", "<S>", "<E>", "<P>",
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
    "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
    "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
    "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
    " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?",
]
# fmt: on


================================================
FILE: lab05/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path

import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"

CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)  # width variable, depends on maximum sequence length

MAPPING = emnist.MAPPING


================================================
FILE: lab05/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared

RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"

DOWNSAMPLE_FACTOR = 2  # if images were downsampled, the regions must also be
LINE_REGION_PADDING = 8  # add this many pixels around the exact coordinates


================================================
FILE: lab05/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines"

IMAGE_SCALE_FACTOR = 2

CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR  # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR # rounding up IAMLines empirical maximum width DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (89, 1) MAPPING = emnist.MAPPING ================================================ FILE: lab05/text_recognizer/metadata/iam_paragraphs.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN] IMAGE_SCALE_FACTOR = 2 IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640 IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) MAX_LABEL_LENGTH = 682 DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) ================================================ FILE: lab05/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab05/text_recognizer/metadata/shared.py ================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab05/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple from .resnet_transformer import ResnetTransformer from .line_cnn_transformer import LineCNNTransformer ================================================ FILE: lab05/text_recognizer/models/cnn.py 
================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. 
conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab05/text_recognizer/models/line_cnn.py ================================================ """Basic building blocks for convolutional models over lines of text.""" import argparse import math from typing import Any, Dict, Tuple, Union import torch from torch import nn import torch.nn.functional as F # Common type hints Param2D = Union[int, Tuple[int, int]] CONV_DIM = 32 FC_DIM = 512 FC_DROPOUT = 0.2 WINDOW_WIDTH = 16 WINDOW_STRIDE = 8 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. 
""" def __init__( self, input_channels: int, output_channels: int, kernel_size: Param2D = 3, stride: Param2D = 1, padding: Param2D = 1, ) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class LineCNN(nn.Module): """ Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits """ def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.args = vars(args) if args is not None else {} self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] _C, H, _W = data_config["input_dims"] conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) # Input is (1, H, W) self.convs = nn.Sequential( ConvBlock(1, conv_dim), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim, stride=2), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim * 2, stride=2), ConvBlock(conv_dim * 2, conv_dim * 2), ConvBlock(conv_dim * 2, conv_dim * 4, stride=2), ConvBlock(conv_dim * 4, conv_dim * 4), ConvBlock( conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0 ), ) self.fc1 = nn.Linear(fc_dim, fc_dim) self.dropout = nn.Dropout(fc_dropout) self.fc2 = nn.Linear(fc_dim, self.num_classes) self._init_weights() def _init_weights(self): """ Initialize weights in a better way 
than default. See https://github.com/pytorch/pytorch/issues/18182 """ for m in self.modules(): if type(m) in { nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear, }: nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu") if m.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(m.bias, -bound, bound) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the LineCNN to a black-and-white input image. Parameters ---------- x (B, 1, H, W) input image Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and self.window_width C is self.num_classes """ _B, _C, _H, _W = x.shape x = self.convs(x) # (B, FC_DIM, 1, Sx) x = x.squeeze(2).permute(0, 2, 1) # (B, S, FC_DIM) x = F.relu(self.fc1(x)) # -> (B, S, FC_DIM) x = self.dropout(x) x = self.fc2(x) # (B, S, C) x = x.permute(0, 2, 1) # -> (B, C, S) if self.limit_output_length: x = x[:, :, : self.output_length] return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab05/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import 
torch from torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. 
Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab05/text_recognizer/models/line_cnn_transformer.py ================================================ """Model that combines a LineCNN with a Transformer model for text prediction.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .line_cnn import LineCNN from .transformer_util import generate_square_subsequent_mask, PositionalEncoding TF_DIM = 256 TF_FC_DIM = 256 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 class LineCNNTransformer(nn.Module): """Process the line through a CNN and process the resulting sequence with a Transformer decoder.""" def __init__( self, data_config: 
Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # Instantiate LineCNN with "num_classes" set to self.dim data_config_for_line_cnn = {**data_config} data_config_for_line_cnn["mapping"] = list(range(self.dim)) self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args) # LineCNN outputs (B, E, S) log probs, with E == dim self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.pos_encoder = PositionalEncoding(d_model=self.dim) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (Sx, B, E) logits """ x = self.line_cnn(x) # (B, E, Sx) x = x * math.sqrt(self.dim) x = x.permute(2, 0, 1) # (Sx, B, E) x = self.pos_encoder(x) # (Sx, B, E) return x def decode(self, x, y): """Decode a batch of encoded images x using preceding ground truth y.
Parameters ---------- x (Sx, B, E) image encoded as a sequence y (B, Sy) with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) logits """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output def forward(self, x: torch.Tensor) -> torch.Tensor: """Predict sequences of tokens from input images auto-regressively. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1:] # Set the last output token # Set all tokens after end token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) @staticmethod def add_to_argparse(parser): LineCNN.add_to_argparse(parser) parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser ================================================ FILE: lab05/text_recognizer/models/mlp.py 
================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab05/text_recognizer/models/resnet_transformer.py ================================================ """Model combining a ResNet with a Transformer for image-to-sequence tasks.""" import argparse import math from typing import Any, Dict import torch from torch import nn import torchvision from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage TF_DIM = 256 TF_FC_DIM = 1024 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 RESNET_DIM = 512 # hard-coded class ResnetTransformer(nn.Module): """Pass an image through a Resnet and decode the resulting embedding with a Transformer.""" def __init__( 
self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) self.mapping = data_config["mapping"] inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # ## Encoder part - should output vector sequence of length self.dim per sample resnet = torchvision.models.resnet18(weights=None) self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) # encoder_projection will output (B, dim, _H, _W) logits self.enc_pos_encoder = PositionalEncodingImage( d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] ) # Max (Ho, Wo) # ## Decoder part self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def forward(self, x: torch.Tensor) -> torch.Tensor: """Autoregressively produce sequences of labels from input images.
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- output_tokens (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1] # Set the last output token # Early stopping of prediction loop to speed up prediction if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): break # Set all tokens after end or padding token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") if self.encoder_projection.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.encoder_projection.bias, -bound, bound) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps """ _B, C, _H, _W = x.shape if C == 1: x = x.repeat(1, 3, 1, 1) x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo return x def decode(self, x, y): """Decode a batch of encoded images x with guiding sequences y. During autoregressive inference, the guiding sequence will be previous predictions. During training, the guiding sequence will be the ground truth. Parameters ---------- x (Sx, B, E) images encoded as sequences of embeddings y (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) batch of logit sequences """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.dec_pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output @staticmethod def add_to_argparse(parser): parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser
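The shape comments in `ResnetTransformer.encode` state that the ResNet trunk shrinks each spatial dimension by a factor of 32 and that the flattened feature map becomes the encoder sequence of length `Sx = Ho * Wo`. As a rough sketch of that arithmetic only (the helper name `encoded_sequence_length` is hypothetical and not part of the repository, and the floor division assumes, as the comments do, that height and width are multiples of 32), the IAMParagraphs figures quoted there can be reproduced like this:

```python
from typing import Tuple


def encoded_sequence_length(height: int, width: int, downsample: int = 32) -> Tuple[int, int, int]:
    """Return (Ho, Wo, Sx) for an input image of the given size.

    Hypothetical helper: mirrors the `_H // 32, _W // 32` comments in encode(),
    which are exact when both dimensions are divisible by `downsample`.
    """
    ho = height // downsample  # feature-map height after the ResNet trunk
    wo = width // downsample  # feature-map width after the ResNet trunk
    return ho, wo, ho * wo  # Sx is the flattened feature-map length


# IMAGE_SHAPE for IAMParagraphs is (576, 640) according to the stem's comment.
ho, wo, sx = encoded_sequence_length(576, 640)
print(ho, wo, sx)  # 18 20 360
```

Under these assumptions the decoder attends over 360 memory positions per paragraph image, which matches the `(B, 256, 18, 20)` shape noted in the comments once the last two dimensions are flattened.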
================================================ FILE: lab05/text_recognizer/models/transformer_util.py ================================================ """Position Encoding and other utilities for Transformers.""" import math import torch from torch import Tensor import torch.nn as nn class PositionalEncodingImage(nn.Module): """ Module used to add 2-D positional encodings to the feature-map produced by the encoder. Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. """ def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None: super().__init__() self.d_model = d_model assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w) self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor: pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1, d_model // 2) pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) return pe def forward(self, x: Tensor) -> Tensor: """pytorch.nn.module.forward""" # x.shape = (B, d_model, H, W) assert x.shape[1] == self.pe.shape[0] # type: ignore x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore return x class PositionalEncoding(torch.nn.Module): """Classic Attention-is-all-you-need positional encoding.""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None: super().__init__() self.dropout = torch.nn.Dropout(p=dropout) pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model)
self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_len: int) -> torch.Tensor: pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(1) return pe def forward(self, x: torch.Tensor) -> torch.Tensor: # x.shape = (S, B, d_model) assert x.shape[2] == self.pe.shape[2] # type: ignore x = x + self.pe[: x.size(0)] # type: ignore return self.dropout(x) def generate_square_subsequent_mask(size: int) -> torch.Tensor: """Generate a triangular (size, size) mask.""" mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask ================================================ FILE: lab05/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible.
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab05/text_recognizer/stems/line.py ================================================ import random from PIL import Image from torchvision import transforms import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.image import ImageStem class LineStem(ImageStem): """A stem for handling images containing a line of text.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.5, 1)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "translate": (0, 0.05), "scale": (0.4, 1.1), "shear": (-40, 50), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } if augment: self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] ) class IAMLineStem(ImageStem): """A stem for handling images containing lines of text from the IAMLines dataset.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() def embed_crop(crop, augment=augment): # crop is PIL.image of dtype="L" (so values range from 0 -> 255) image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) # Resize crop crop_width, crop_height = crop.size new_crop_height = metadata.IMAGE_HEIGHT new_crop_width = int(new_crop_height * 
(crop_width / crop_height)) if augment: # Add random stretching new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR) # Embed in the image x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) y = metadata.IMAGE_HEIGHT - new_crop_height image.paste(crop_resized, (x, y)) return image if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.8, 1.6)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 1, "shear": (-30, 20), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } pil_transforms_list = [transforms.Lambda(embed_crop)] if augment: pil_transforms_list += [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] self.pil_transforms = transforms.Compose(pil_transforms_list) ================================================ FILE: lab05/text_recognizer/stems/paragraph.py ================================================ """IAMParagraphs Stem class.""" import torchvision.transforms as transforms import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.image import ImageStem IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH class ParagraphStem(ImageStem): """A stem for handling images that contain a paragraph of text.""" def __init__( self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None, random_perspective_kwargs=None, gaussian_blur_kwargs=None, sharpness_kwargs=None, ): super().__init__() if not augment: self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)]) else: if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "shear": 6, 
"scale": (0.95, 1), "interpolation": transforms.InterpolationMode.BILINEAR, } if random_perspective_kwargs is None: random_perspective_kwargs = { "distortion_scale": 0.2, "p": 0.5, "interpolation": transforms.InterpolationMode.BILINEAR, } if gaussian_blur_kwargs is None: gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if sharpness_kwargs is None: sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} # IMAGE_SHAPE is (576, 640) self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomCrop( size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant" ), transforms.RandomAffine(**random_affine_kwargs), transforms.RandomPerspective(**random_perspective_kwargs), transforms.GaussianBlur(**gaussian_blur_kwargs), transforms.RandomAdjustSharpness(**sharpness_kwargs), ] ) ================================================ FILE: lab05/text_recognizer/tests/test_callback_utils.py ================================================ """Tests for the text_recognizer.callbacks.util module.""" import random import string import tempfile import pytorch_lightning as pl from text_recognizer.callbacks.util import check_and_warn def test_check_and_warn_simple(): """Test the success and failure in the case of a simple class we control.""" class Foo: pass # a class with no special attributes letters = string.ascii_lowercase random_attribute = "".join(random.choices(letters, k=10)) assert check_and_warn(Foo(), random_attribute, "random feature") assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects") def test_check_and_warn_tblogger(): """Test that we return a truthy value when trying to log tables with TensorBoard. We added check_and_warn in order to prevent a crash if this happens. 
""" tblogger = pl.loggers.TensorBoardLogger(save_dir=tempfile.TemporaryDirectory()) assert check_and_warn(tblogger, "log_table", "tables") def test_check_and_warn_wandblogger(): """Test that we return a falsy value when we try to log tables with W&B. In adding check_and_warn, we don't want to block the feature in the happy path. """ wandblogger = pl.loggers.WandbLogger(anonymous=True) assert not check_and_warn(wandblogger, "log_table", "tables") ================================================ FILE: lab05/text_recognizer/tests/test_iam.py ================================================ """Test for data.iam module.""" from text_recognizer.data.iam import IAM def test_iam_parsed_lines(): """Tests that we retrieve the same number of line labels and line image crop regions.""" iam = IAM() iam.prepare_data() for iam_id in iam.all_ids: assert len(iam.line_strings_by_id[iam_id]) == len(iam.line_regions_by_id[iam_id]) def test_iam_data_splits(): """Fails when any identifiers are shared between training, test, or validation.""" iam = IAM() iam.prepare_data() assert not set(iam.train_ids) & set(iam.validation_ids) assert not set(iam.train_ids) & set(iam.test_ids) assert not set(iam.validation_ids) & set(iam.test_ids) ================================================ FILE: lab05/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file,
grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab05/training/__init__.py ================================================ ================================================ FILE: lab05/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import callbacks as cb from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, 
MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--wandb", action="store_true", default=False, help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.", ) parser.add_argument( "--profile", action="store_true", default=False, help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.", ) parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
+ " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. 
To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.loss == "transformer": lit_model_class = lit_models.TransformerLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" if goldstar_metric == "validation/cer": filename_format += "-validation.cer={validation/cer:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.wandb: logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train") logger.watch(model, log_freq=max(100, args.log_every_n_steps)) logger.log_hyperparams(vars(args)) experiment_dir = logger.experiment.dir callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) if args.wandb and args.loss in ("transformer",): callbacks.append(cb.ImageToTextLogger()) trainer = 
pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) if args.profile: sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0) profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir) profiler.STEP_FUNCTIONS = {"training_step"} # only profile training else: profiler = pl.profiler.PassThroughProfiler() trainer.profiler = profiler trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) trainer.profiler = pl.profiler.PassThroughProfiler() # turn profiling off during testing best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") if args.wandb: rank_zero_info("Best model also uploaded to W&B ") trainer.test(datamodule=data, ckpt_path=best_model_path) else: trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab05/training/tests/test_memorize_iam.sh ================================================ #!/bin/bash set -uo pipefail set +e # tests whether we can achieve a criterion loss # on a single batch within a certain number of epochs FAILURE=false # constants and CLI args set by aiming for <5 min test on commodity GPU, # including data download step MAX_EPOCHS="${1:-100}" # syntax for basic optional arguments in bash CRITERION="${2:-1.0}" # train on GPU if it's available GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))') python ./training/run_experiment.py \ --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \ --augment_data false --tf_dropout 0.0 \ --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \ --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true python -c "import json; loss = 
json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true if [ "$FAILURE" = true ]; then echo "Memorization test failed at loss criterion $CRITERION" exit 1 fi echo "Memorization test passed at loss criterion $CRITERION" exit 0 ================================================ FILE: lab05/training/tests/test_run_experiment.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false echo "running full loop test with CNN on fake data" python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1 || FAILURE=true echo "running fast_dev_run test of real model class on real data" python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \ --fast_dev_run --num_sanity_val_steps 0 \ --num_workers 1 || FAILURE=true if [ "$FAILURE" = true ]; then echo "Test for run_experiment.py failed" exit 1 fi echo "Tests for run_experiment.py passed" exit 0 ================================================ FILE: lab05/training/util.py ================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 
'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: lab06/.flake8 ================================================ [flake8] select = ANN,B,B9,BLK,C,D,E,F,I,S,W # only check selected error codes max-complexity = 12 # C9 - flake8 McCabe Complexity checker -- threshold max-line-length = 120 # E501 - flake8 -- line length too long, actually handled by black extend-ignore = # E W - flake8 PEP style check E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks # S - flake8-bandit safety check S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password # ANN - flake8-annotations type annotation check ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some # D1 - flake8-docstrings docstring style check D100,D102,D103,D104,D105, # missing docstrings # D2 D4 - flake8-docstrings docstring style check D200,D205,D400,D401, # whitespace issues and first line content # DAR - flake8-darglint docstring correctness check DAR103, # mismatched or missing type in docstring application-import-names = app_gradio,text_recognizer,tests,training # flake8-import-order: which names are first party? import-order-style = google # flake8-import-order: which import order style guide do we use? docstring-convention = numpy # flake8-docstrings: which docstring style guide do we use? strictness = short # darglint: how "strict" are we with docstring completeness? 
docstring-style = numpy # darglint: which docstring style guide do we use? suppress-none-returning = true # flake8-annotations: do we allow un-annotated Nones in returns? mypy-init-return = true # flake8-annotations: do we allow init to have no return annotation? per-file-ignores = # list of case-by-case ignores, see files for details */__init__.py:F401,I */data/*.py:DAR data/*.py:F,I *text_recognizer/util.py:DAR101,F401 *training/run_experiment.py:I202 *app_gradio/app.py:I202 ================================================ FILE: lab06/.github/workflows/pre-commit.yml ================================================ name: pre-commit on: pull_request: push: # allows this Action to be triggered manually workflow_dispatch: jobs: pre-commit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: python-version: '3.10' - uses: pre-commit/action@v3.0.0 ================================================ FILE: lab06/.pre-commit-config.yaml ================================================ repos: # a set of useful Python-based pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: # list of definitions and supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace # removes any whitespace at the ends of lines - id: check-toml # check toml syntax by loading all toml files - id: check-yaml # check yaml syntax by loading all yaml files - id: check-json # check-json syntax by loading all json files - id: check-merge-conflict # check for files with merge conflict strings args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge - id: check-added-large-files # check that no "large" files have been added args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server - id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.) - id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.) 
# black python autoformatting - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black # additional configuration of black in pyproject.toml # flake8 python linter with all the fixins - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08) additional_dependencies: [ flake8-bandit, flake8-bugbear, flake8-docstrings, flake8-import-order, darglint, mypy, pycodestyle, pydocstyle] args: ["--config", ".flake8"] # additional configuration of flake8 and extensions in .flake8 # shellcheck-py for linting shell files - repo: https://github.com/shellcheck-py/shellcheck-py rev: v0.8.0.4 hooks: - id: shellcheck ================================================ FILE: lab06/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid 
training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For a refresher on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by LeCun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": 
"markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the Python standard library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing\n", "and matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models." 
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch." 
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how good our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5  # learning rate hyperparameter\n", "epochs = 2  # how many epochs to train for\n", "\n", "for epoch in range(epochs):  # loop over the data repeatedly\n", "    for ii in range((n - 1) // bs + 1):  # in batches of size bs, so roughly n / bs of them\n", "        start_idx = ii * bs  # we are ii batches in, each of size bs\n", "        end_idx = start_idx + bs  # and we want the next bs entries\n", "\n", "        # pull batches from x and from y\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "\n", "        # run model\n", "        pred = model(xb)\n", "\n", "        # get loss\n", "        loss = loss_func(pred, yb)\n", "\n", "        # calculate the gradients with a backwards pass\n", "        loss.backward()\n", "\n", "        # update the parameters\n", "        with torch.no_grad():  # we don't want to track gradients through this part!\n", "            # SGD learning rule: update with negative gradient scaled by lr\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "\n", "            # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", "            # until you say so -- by explicitly \"deleting\" them,\n", "            # i.e. 
setting the gradients to 0.\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1)  # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset were, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`." 
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()  # the nn.Module.__init__ method does important setup, so this is mandatory\n", "        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", "        self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once." 
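, "\n", "\n", "To make that behind-the-scenes dispatch concrete, here is a minimal pure-Python sketch\n", "of a class whose instances can be called like functions.\n", "The names `TinyModule` and `Scale` are hypothetical, made up for this illustration;\n", "the real `nn.Module.__call__` does extra bookkeeping, like running hooks, around the call to `.forward`.\n", "\n", "```python\n", "class TinyModule:\n", "    # hypothetical stand-in for nn.Module, for illustration only\n", "    def __call__(self, *args, **kwargs):\n", "        # calling the instance hands the arguments to .forward\n", "        return self.forward(*args, **kwargs)\n", "\n", "\n", "class Scale(TinyModule):\n", "    def __init__(self, factor):\n", "        self.factor = factor  # state lives on the object\n", "\n", "    def forward(self, x):\n", "        return self.factor * x\n", "\n", "\n", "double = Scale(2)  # instantiated as an object\n", "print(double(21))  # callable like a function\n", "```"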
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
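, "\n", "\n", "One way to picture how a module can hand back its parameters without us naming them\n", "is to scan the object's attributes and recurse into any child modules.\n", "The sketch below does this in pure Python under that assumption;\n", "`ToyParameter`, `ToyModule`, `ToyLinear`, and `ToyNet` are hypothetical names,\n", "and this is not how PyTorch is implemented internally\n", "(PyTorch registers parameters at the moment they are assigned as attributes).\n", "\n", "```python\n", "class ToyParameter:\n", "    # hypothetical marker type playing the role of nn.Parameter\n", "    def __init__(self, value):\n", "        self.value = value\n", "\n", "\n", "class ToyModule:\n", "    # hypothetical stand-in for nn.Module\n", "    def parameters(self):\n", "        for attr in vars(self).values():\n", "            if isinstance(attr, ToyParameter):\n", "                yield attr  # state owned directly by this module\n", "            elif isinstance(attr, ToyModule):\n", "                yield from attr.parameters()  # recurse into child modules\n", "\n", "\n", "class ToyLinear(ToyModule):\n", "    def __init__(self):\n", "        self.weights = ToyParameter([[0.0]])\n", "        self.bias = ToyParameter([0.0])\n", "\n", "\n", "class ToyNet(ToyModule):\n", "    def __init__(self):\n", "        self.lin = ToyLinear()\n", "        self.scale = ToyParameter(1.0)\n", "        self.name = 'not a parameter'  # skipped by .parameters\n", "\n", "\n", "# two parameters from the child plus one owned by the parent\n", "print(len(list(ToyNet().parameters())))\n", "```"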
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): # finds params automatically\n", " p -= p.grad * lr\n", " model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", " def forward(self, xb):\n", " return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
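, "\n", "\n", "For intuition about what an optimizer object takes off our hands,\n", "here is a minimal pure-Python sketch with the same two methods our loop will call, `.step` and `.zero_grad`.\n", "`ToySGD` and the dictionary-based parameters are hypothetical, invented for this example;\n", "it mirrors the plain SGD rule from our manual loop, not the more involved update that `Adam` performs.\n", "\n", "```python\n", "class ToySGD:\n", "    # hypothetical stand-in for an optimizer;\n", "    # each parameter is a dict with a 'value' and a 'grad'\n", "    def __init__(self, params, lr):\n", "        self.params = list(params)\n", "        self.lr = lr\n", "\n", "    def step(self):\n", "        for p in self.params:\n", "            p['value'] -= self.lr * p['grad']  # same rule as our manual update\n", "\n", "    def zero_grad(self):\n", "        for p in self.params:\n", "            p['grad'] = 0.0  # reset so the next gradient starts fresh\n", "\n", "\n", "w = {'value': 1.0, 'grad': 0.5}\n", "toy_opt = ToySGD([w], lr=0.5)\n", "toy_opt.step()\n", "toy_opt.zero_grad()\n", "print(w)\n", "```"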
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
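, "\n", "\n", "Before we do, here is a minimal sketch of the two-method indexing interface in pure Python.\n", "`PairedDataset` is a hypothetical name, and this is not the `BaseDataset` implementation;\n", "a real PyTorch dataset would also inherit from `torch.utils.data.Dataset`,\n", "which we skip here to keep the sketch dependency-free.\n", "\n", "```python\n", "class PairedDataset:\n", "    # hypothetical example: keeps inputs and targets aligned under one index\n", "    def __init__(self, data, targets):\n", "        if len(data) != len(targets):\n", "            raise ValueError('data and targets must have the same length')\n", "        self.data = data\n", "        self.targets = targets\n", "\n", "    def __len__(self):\n", "        return len(self.data)\n", "\n", "    def __getitem__(self, index):\n", "        # the same index (int or slice) is applied to both sequences\n", "        return self.data[index], self.targets[index]\n", "\n", "\n", "ds = PairedDataset(['a', 'b', 'c'], [0, 1, 2])\n", "print(len(ds), ds[1], ds[0:2])\n", "```"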
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", "    opt = configure_optimizer(self)\n", "\n", "    for epoch in range(epochs):\n", "        for xb, yb in train_dataloader:\n", "            pred = self(xb)\n", "            loss = loss_func(pred, yb)\n", "\n", "            loss.backward()\n", "            opt.step()\n", "            opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten-line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!" 
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", " weights.grad.zero_()\n", " bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move electrons around." 
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but as networks grow, computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", "    return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
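, "\n", "\n", "To see why performance on seen data alone can mislead,\n", "consider a deliberately silly sketch in pure Python:\n", "a lookup table that memorizes every training pair.\n", "The data here is hypothetical (integer inputs with random digit labels, made up for this example),\n", "so there is nothing to generalize from, which makes the gap between the two scores easy to see.\n", "\n", "```python\n", "import random\n", "\n", "random.seed(0)\n", "# hypothetical toy data: integer inputs with random digit labels\n", "train = {x: random.randrange(10) for x in range(1000)}\n", "valid = {x: random.randrange(10) for x in range(1000, 2000)}\n", "\n", "memory = dict(train)  # 'training' just stores every (input, label) pair\n", "\n", "\n", "def predict(x):\n", "    return memory.get(x, 0)  # fall back to a fixed guess on unseen inputs\n", "\n", "\n", "def accuracy(data):\n", "    return sum(predict(x) == y for x, y in data.items()) / len(data)\n", "\n", "\n", "# perfect on the memorized data, roughly chance level on held-out data\n", "print(accuracy(train), accuracy(valid))\n", "```"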
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", " \n", " def __init__(self, dir, bs=32):\n", " self.dir = dir\n", " self.bs = bs\n", " self.path = self.dir / self.filename\n", "\n", " def prepare_data(self):\n", " if not (self.path).exists():\n", " content = requests.get(self.url + self.filename).content\n", " self.path.open(\"wb\").write(content)\n", "\n", " def setup(self):\n", " with gzip.open(self.path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", " x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", " )\n", " \n", " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", " def train_dataloader(self):\n", " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", " \n", " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", " datamodule.prepare_data()\n", " datamodule.setup()\n", "\n", " val_dataloader = datamodule.val_dataloader()\n", " \n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", " opt = configure_optimizer(self)\n", " train_dataloader = datamodule.train_dataloader()\n", " for epoch in range(epochs):\n", " self.train()\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
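, "\n", "\n", "As a minimal sketch of both points (an illustration, not code from this lab), the snippet below passes the same input\n", "through a standalone `torch.nn.Dropout` layer in each mode; the names `layer` and `xs` are made up for this example:\n", "\n", "```python\n", "import torch\n", "\n", "layer = torch.nn.Dropout(p=0.5) # hypothetical standalone layer, for illustration only\n", "xs = torch.ones(8)\n", "\n", "layer.train() # training mode: entries are randomly zeroed, survivors are scaled by 1 / (1 - p)\n", "train_out = layer(xs)\n", "\n", "layer.eval() # evaluation mode: dropout is a no-op\n", "with torch.no_grad(): # skip gradient tracking, as in the validation loop\n", "    eval_out = layer(xs)\n", "\n", "print(train_out) # a random mix of 0.0 and 2.0\n", "print(eval_out) # all ones, identical to xs\n", "```\n", "\n", "In `train` mode the output changes from call to call, while in `eval` mode the layer passes its input through unchanged." 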
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here's some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", "    def __init__(self): # add args and kwargs here as you like\n", "        super().__init__()\n", "        # use those args and kwargs to set up the submodules\n", "        self.ps = nn.Parameter(torch.zeros(10))\n", "\n", "    def forward(self, xb): # overwrite this to use your nn.Modules from above\n", "        xb = torch.stack([self.ps for ii in range(len(xb))])\n", "        return xb\n", "\n", "\n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab06/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", "    # path management for Python\n", "    pythonpath, = !echo $PYTHONPATH\n", "    if \".\" not in pythonpath.split(\":\"):\n", "        pythonpath = \".:\" + pythonpath\n", "        %env PYTHONPATH={pythonpath}\n", "        !echo $PYTHONPATH\n", "\n", "    # get both Colab and local notebooks into the same state\n", "    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", "    import bootstrap\n", "\n", "    # change into the lab directory\n", "    bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", "    # allow \"hot-reloading\" of modules\n", "    %load_ext autoreload\n", "    %autoreload 2\n", "    # needed for inline plots in some contexts\n", "    %matplotlib inline\n", "\n", "    bootstrap.run = False # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", "    inputs, targets = batch\n", "\n", "    outputs = model(inputs)\n", "    loss = some_loss_function(outputs, targets)\n", "\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "\n", "    optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes." 
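, "\n", "\n", "To make the checkpointing need concrete, here is a minimal plain-PyTorch sketch of saving and restoring state by hand.\n", "The toy `torch.nn.Linear` model, the temporary file path, and the epoch number are all hypothetical stand-ins:\n", "\n", "```python\n", "import os\n", "import tempfile\n", "\n", "import torch\n", "\n", "model = torch.nn.Linear(1, 1) # toy stand-in for a real model\n", "optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "path = os.path.join(tempfile.mkdtemp(), \"checkpoint.pt\") # hypothetical location\n", "torch.save({\"model\": model.state_dict(), \"optimizer\": optimizer.state_dict(), \"epoch\": 3}, path)\n", "\n", "# later, e.g. after a crash: rebuild the objects, then restore their state\n", "restored = torch.nn.Linear(1, 1)\n", "restored_optimizer = torch.optim.Adam(restored.parameters(), lr=3e-4)\n", "checkpoint = torch.load(path)\n", "restored.load_state_dict(checkpoint[\"model\"])\n", "restored_optimizer.load_state_dict(checkpoint[\"optimizer\"])\n", "start_epoch = checkpoint[\"epoch\"] + 1 # resume from the next epoch\n", "```\n", "\n", "Even this tiny version has to track the epoch and the optimizer state by hand." 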
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", "    def __init__(self):\n", "        super().__init__() # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", "        # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", "        self.model = torch.nn.Linear(in_features=1, out_features=1)\n", "        # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", "    # optionally, define a forward method\n", "    def forward(self, xs):\n", "        return self.model(xs) # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try: # try using the LinearRegression LightningModule defined above\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR) # hide some info for now\n", "\n", "    model = LinearRegression()\n", "\n", "    # we'll explain how the Trainer works in a bit\n", "    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", "    trainer.fit(model=model)\n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", "    print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\") # show the error without raising it\n", "\n", "finally: # bring back info-level logging\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message 
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", " # run the train step\n", " loss = training_step(batch)\n", "\n", " # clear gradients\n", " optimizer.zero_grad()\n", "\n", " # backprop\n", " loss.backward()\n", "\n", " # update parameters\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " xs, ys = batch # unpack the batch\n", " outs = self(xs) # apply the model\n", " loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n", " return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we 
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudocode (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", "    def closure(): # wrap the loss calculation\n", "        loss = model.training_step(batch, batch_idx, ...)\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        return loss\n", "\n", "    # optimizer can call the loss calculation as many times as it likes\n", "    optimizer.step(closure) # some optimizers need this, like (L)-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", "    optimizer = torch.optim.Adam(self.parameters(), lr=3e-4) # https://fsdl.me/ol-reliable-img\n", "    return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n", 
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", "    def __init__(self, N=10_000):\n", "        self.N = N\n", "        self.xs = torch.randn(size=(N, 1))\n", "        self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n", "\n", "    def __getitem__(self, idx):\n", "        return (self.xs[idx], self.ys[idx])\n", "\n", "    def __len__(self):\n", "        return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", "    .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging " ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", "    def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", "        super().__init__()  # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", "        # set some constants, like the train/val split\n", "        self.size = size\n", "        self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", "        self.train_indices = list(range(math.floor(self.size * train_frac)))\n", "        # validation indices start right after the last training index, so the two splits don't overlap\n", "        self.val_indices = list(range(len(self.train_indices), self.size))\n", "\n", "        # under the hood, we've still got a torch Dataset\n", "        self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up once on each machine\n", "(for example, downloading data from storage and writing it to the local disk)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates best performance practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes." 
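, "\n", "\n", "To make point 2 concrete, here is a minimal plain-Python sketch of the accumulate-then-compute pattern.\n", "The class `RunningMeanSquaredError` is hypothetical and written only for illustration:\n", "it is not the `torchmetrics` implementation, though it borrows the `update`, `compute`, and `reset` method names from the `Metric` interface.\n", "\n", "```python\n", "class RunningMeanSquaredError:\n", "    # hypothetical, simplified stand-in for a metric with persistent state\n", "\n", "    def __init__(self):\n", "        self.reset()\n", "\n", "    def reset(self):\n", "        # clear the running values, e.g. at the start of an epoch\n", "        self.sum_squared_error, self.count = 0.0, 0\n", "\n", "    def update(self, preds, targets):\n", "        # fold one batch into the running values\n", "        self.sum_squared_error += sum((p - t) ** 2 for p, t in zip(preds, targets))\n", "        self.count += len(preds)\n", "\n", "    def compute(self):\n", "        # combine the running values into the final metric\n", "        return self.sum_squared_error / self.count\n", "\n", "\n", "metric = RunningMeanSquaredError()\n", "metric.update(preds=[1.0, 2.0], targets=[1.0, 4.0])  # batch of two: squared errors 0 and 4\n", "metric.update(preds=[3.0], targets=[1.0])  # batch of one: squared error 4\n", "print(metric.compute())  # 8 / 3, the mean over all three examples\n", "```\n", "\n", "Keeping the running sum and count as state means the final value is the mean over all examples,\n", "which can differ from the mean of per-batch means when batches have different sizes."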
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    pass  # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    pass  # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", "    def __init__(self, N=10_000, N_test=10_000):  # reimplement __init__ here\n", "        super().__init__()  # don't forget this!\n", "        self.dataset = None\n", "        self.test_dataset = None  # define a test set -- another sample from the same distribution\n", "\n", "    def setup(self, stage=None):\n", "        pass\n", "\n", "    def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", "        pass  # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null,
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab06/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", "    # path management for Python\n", "    pythonpath, = !echo $PYTHONPATH\n", "    if \".\" not in pythonpath.split(\":\"):\n", "        pythonpath = \".:\" + pythonpath\n", "        %env PYTHONPATH={pythonpath}\n", "        !echo $PYTHONPATH\n", "\n", "    # get both Colab and local notebooks into the same state\n", "    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", "    import bootstrap\n", "\n", "    # change into the lab directory\n", "    bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", "    # allow \"hot-reloading\" of modules\n", "    %load_ext autoreload\n", "    %autoreload 2\n", "    # needed for inline plots in some contexts\n", "    %matplotlib inline\n", "\n", "    bootstrap.run = False  # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?"
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
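, "\n", "\n", "As a rough check of this idea, the sketch below implements a circular one-dimensional convolution in plain Python\n", "and confirms that shifting the input and then applying the kernel gives the same result\n", "as applying the kernel and then shifting the output.\n", "The helper names `roll` and `circular_conv` are hypothetical, chosen for illustration, and are not part of the FSDL codebase:\n", "\n", "```python\n", "def roll(xs, shift):\n", "    # circularly shift a list to the right by `shift` positions\n", "    shift %= len(xs)\n", "    return xs[-shift:] + xs[:-shift]\n", "\n", "\n", "def circular_conv(xs, kernel):\n", "    # slide the kernel over the input, wrapping around at the edges\n", "    n = len(xs)\n", "    return [sum(k * xs[(i + j) % n] for j, k in enumerate(kernel)) for i in range(n)]\n", "\n", "\n", "signal = [0, 1, 4, 2, 0, 0, 3, 1]\n", "kernel = [1, -2, 1]\n", "\n", "# translation equivariance: shift-then-convolve equals convolve-then-shift\n", "for shift in range(len(signal)):\n", "    assert circular_conv(roll(signal, shift), kernel) == roll(circular_conv(signal, kernel), shift)\n", "```\n", "\n", "The check passes for every shift because the same kernel weights are reused at every position."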
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight  # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we apply that kernel repeatedly,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type":
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh"
}, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (to right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represent (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "arguments.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of 
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work off of,\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors." 
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = max(cnn.parameters(), key=lambda p: p.numel()) # first parameter tensor with the most elements\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like DropOut\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by the layer can be higher\n", "than that done by layers with more parameters --\n", "much higher." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2D layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n", " num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", " input_height, input_width = input_size[1], input_size[2]\n", "\n", " multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", " num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n", " mutliplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", " return mutliplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n", "filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"CNN\",\n", " \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", " --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels." 
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
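] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of the windowing described above,\n", "the sketch below cuts a wide image into square windows with `Tensor.unfold`.\n", "The variable names and the value 28 are assumptions made for this example;\n", "check the output of `line_cnn.forward??` for what `LineCNNSimple` actually does." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"window_width, window_stride = 28, 28  # hypothetical values, chosen for illustration\n",
"wide_image = torch.randn(1, 1, 28, 28 * 4)  # (batch, channels, height, width)\n",
"\n",
"# cut the wide image into windows along the width dimension\n",
"# result: (batch, channels, height, num_windows, window_width)\n",
"windows = wide_image.unfold(3, window_width, window_stride)\n",
"# reorder to (batch, num_windows, channels, height, window_width)\n",
"windows = windows.permute(0, 3, 1, 2, 4)\n",
"\n",
"# a model for square images could now be applied to each window in turn\n",
"windows.shape"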
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `

`.\n", "\n", "A model that always predicts `

` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"

\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab06/notebooks/lab03_transformers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 03: Transformers and Paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The fundamental reasons why the Transformer is such\n", "a powerful and popular architecture\n", "- Core intuitions for the behavior of Transformer architectures\n", "- How to use a convolutional encoder and a Transformer decoder to recognize\n", "entire paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 3\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Transformers?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal in building a text recognizer is to take a two-dimensional image\n", "and convert it into a one-dimensional sequence of characters\n", "from some alphabet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional neural networks,\n", "discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n", "are great at encoding images,\n", "taking them from their raw pixel values\n", "to a more semantically meaningful numerical representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But how do we go from that to a sequence of letters?\n", "And what's especially tricky:\n", "the number of letters in an image is separable from its size.\n", "A screenshot of this document has a much higher density of letters\n", "than a close-up photograph of a piece of paper.\n", "How do we get a _variable-length_ sequence of letters,\n", "where the length need have nothing to do with the size of the input tensor?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n", "they were\n", "[originally introduced](https://arxiv.org/abs/1706.03762)\n", "for transforming one sequence into another,\n", "as in machine translation.\n", "This makes them a natural fit for processing language.\n", "\n", "But they have also found success in other domains --\n", "at the time of this writing, large transformers\n", "dominate the\n", "[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n", "that has become a de facto standard for comparing models\n", "and are finding\n", "[application in reinforcement learning](https://arxiv.org/abs/2106.01345)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we will use a Transformer as a key component of our final architecture:\n", "we will encode our input images with a CNN\n", "and then read them out into a text sequence with a Transformer.\n", "\n", "Before trying out this new model,\n", "let's first get an understanding of why the Transformer architecture\n", "has become so popular by walking through its history\n", "and then get some intuition for how it works\n", "by looking at some\n", "[recent work](https://transformer-circuits.pub/)\n", "on explaining the behavior of both toy models and state-of-the-art language models." 
] }, { "cell_type": "markdown", "metadata": { "id": "kmKqjbvd-Mj3" }, "source": [ "## Why not convolutions?" ] }, { "cell_type": "markdown", "metadata": { "id": "SRqkUMdM-OxU" }, "source": [ "In the ancient beforetimes (i.e. 2016),\n", "the best models for natural language processing were all\n", "_recurrent_ neural networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional networks were also occasionally used,\n", "but they suffered from a serious issue:\n", "their architectural biases don't fit text.\n", "\n", "First, _translation equivariance_ no longer holds.\n", "The beginning of a piece of text is often quite different from the middle,\n", "so the absolute position matters.\n", "\n", "Second, _locality_ is not as important in language.\n", "The name of a character that hasn't appeared in thousands of pages\n", "can become salient when someone asks, \"Whatever happened to\n", "[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n", "\n", "Consider interpreting a piece of text like the Python code below:\n", "```python\n", "def do(arg1, arg2, arg3):\n", " a = arg1 + arg2\n", " b = arg3[:3]\n", " c = a * b\n", " return c\n", "\n", "print(do(1, 1, \"ayy lmao\"))\n", "```\n", "\n", "After a `(` we expect a `)`,\n", "but possibly very long afterwards,\n", "[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n", "and similarly we expect a `]` at some point after a `[`.\n", "\n", "For translation variance, consider\n", "that we interpret `*` not by\n", "comparing it to its neighbors\n", "but by looking at `a` and `b`.\n", "We mix knowledge learned through experience\n", "with new facts learned while reading --\n", "also known as _in-context learning_.\n", "\n", "In a longer text,\n", "[e.g. 
the one you are reading now](./lab03_transformers.ipynb),\n", "the translation variance of text is clearer.\n", "Every lab notebook begins with the same header,\n", "setting up the environment,\n", "but that header never appears elsewhere in the notebook.\n", "Later positions need to be processed in terms of the previous entries.\n", "\n", "Unlike an image, we cannot simply rotate or translate our \"camera\"\n", "and get a new valid text.\n", "[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n", "that can be read without regard to position." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The field of formal language theory,\n", "which has deep mutual influence with computer science,\n", "gives one way of explaining the issues with convolutional networks:\n", "they can only understand languages with _finite contexts_,\n", "where all the information can be found within a finite window." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The immediate solution, drawing from the connections to computer science, is\n", "[recursion](https://www.google.com/search?q=recursion).\n", "A network whose output on the final entry of the sequence is a recursive function\n", "of all the previous entries can build up knowledge\n", "as it reads the sequence and treat early entries quite differently than it does late ones." 
] }, { "cell_type": "markdown", "metadata": { "id": "aa6cbTlImkEh" }, "source": [ "In pseudo-code, such a _recurrent neural network_ module might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "lKtBoPnglPrW" }, "source": [ "```python\n", "def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n", "    next_inputs = input_module(xs[-1])\n", "    next_hiddens = feature_module(recurrent_module(xs[:-1]))  # recursive call\n", "    return output_module(next_inputs, next_hiddens)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "IbJPSMnEm516" }, "source": [ "If you've had formal computer science training,\n", "then you may be familiar with the power of recursion,\n", "e.g. the\n", "[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n", "that gave its name to the now much better-known\n", "[startup incubator](https://www.ycombinator.com/).\n", "\n", "The particular form of recursion used by\n", "recurrent neural networks implements a\n", "[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n", "\n", "> If you know a lot of computer science,\n", "you might be concerned by this connection.\n", "What about other\n", "[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n", "Where are the neural network architectures for differentiable\n", "[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n", "Check out Graph Neural Networks,\n", "[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
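] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the reduce-like reading of the pseudo-code above concrete,\n", "here is a minimal sketch that folds a toy recurrent cell over a sequence with `functools.reduce`.\n", "The names `ToyRNNCell`, `input_dims`, and `hidden_dims` are our own, chosen for illustration;\n", "they are not part of `text_recognizer`, and the `tanh` update is just one simple choice of `output_module`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import functools\n",
"\n",
"import torch\n",
"from torch import nn\n",
"\n",
"\n",
"class ToyRNNCell(nn.Module):  # hypothetical name, for illustration only\n",
"    def __init__(self, input_dims: int, hidden_dims: int):\n",
"        super().__init__()\n",
"        self.input_module = nn.Linear(input_dims, hidden_dims)\n",
"        self.feature_module = nn.Linear(hidden_dims, hidden_dims)\n",
"\n",
"    def forward(self, hiddens: torch.Tensor, x: torch.Tensor) -> torch.Tensor:\n",
"        # combine the running summary of earlier entries with the newest entry\n",
"        return torch.tanh(self.input_module(x) + self.feature_module(hiddens))\n",
"\n",
"\n",
"cell = ToyRNNCell(input_dims=8, hidden_dims=16)\n",
"xs = torch.randn(5, 8)  # a sequence of S=5 entries\n",
"\n",
"# fold the cell over the sequence, one entry at a time, starting from zeros\n",
"final_hiddens = functools.reduce(cell, xs, torch.zeros(16))\n",
"final_hiddens.shape"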
] }, { "cell_type": "markdown", "metadata": { "id": "63mMTbEBpVuE" }, "source": [ "Recurrent networks are able to achieve\n", "[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n", "\n", "There are many popular recurrent architectures,\n", "from the beefy and classic\n", "[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)\n", "to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n", "([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n", "all of which have roughly similar capabilities but\n", "[some of which are easier to train](https://arxiv.org/abs/1611.09913)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwQHVTIslOku" }, "source": [ "In the same sense that MLPs can model \"any\" feedforward function,\n", "in principle even basic RNNs\n", "[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n", "\n", "In particular they can model any\n", "[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n", "which is a formal way of saying that they can in principle\n", "do anything a computer is capable of doing.\n", "\n", "The question is then..." ] }, { "cell_type": "markdown", "metadata": { "id": "3J8EoGN3pu7P" }, "source": [ "## Why aren't we all using RNNs?"
] }, { "cell_type": "markdown", "metadata": { "id": "TDwNWaevpt_3" }, "source": [ "The guarantees that MLPs can model any function\n", "or that RNNs can model Turing machines\n", "provide decent intuition but are not directly practically useful.\n", "Among other reasons, they don't guarantee learnability --\n", "that starting from random parameters we can find the parameters\n", "that implement a given function.\n", "The\n", "[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n", "than would seem from basic theoretical and empirical analysis.\n", "\n", "One way of understanding capacity to model language is\n", "[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n", "In this model of formal languages,\n", "Turing machines sit at the top\n", "([practically speaking](https://arxiv.org/abs/math/0209332)).\n", "\n", "With better mathematical models,\n", "RNNs and LSTMs can be shown to be\n", "[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n", "with RNNs looking more like\n", "[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n", "and LSTMs coming in\n", "[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n", "\n", "More controversially:\n", "the Chomsky hierarchy is great for understanding syntax and grammar,\n", "which makes it great for building parsers\n", "and working with formal languages,\n", "but the goal in _natural_ language processing is to understand _natural_ language.\n", "Most humans' natural language is far from strictly grammatical,\n", "but that doesn't mean it is nonsense.\n", "\n", "And to really \"understand\" language means\n", "to understand its semantic content, which is fuzzy.\n", "The most important thing for handling the fuzzy semantic content\n", "of language is not whether you can recall\n", "[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n", "but whether you can 
model probabilistic relationships between concepts\n", "in addition to grammar and syntax." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These both leave theoretical room for improvement over current recurrent\n", "language and sequence models.\n", "\n", "But the real cause of the rise of Transformers is that..." ] }, { "cell_type": "markdown", "metadata": { "id": "Dsu1ebvAp-3Z" }, "source": [ "## Transformers are designed to train fast at scale on contemporary hardware." ] }, { "cell_type": "markdown", "metadata": { "id": "c4abU5adsPGs" }, "source": [ "The Transformer architecture has several important features,\n", "discussed below,\n", "but one of the most important reasons why it is successful\n", "is that it can be more easily trained at scale.\n", "\n", "This scalability is the focus of the discussion in the paper\n", "that introduced the architecture,\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", "and\n", "[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n", "\n", "The recursion in RNNs is inherently sequential:\n", "the dependence on the outputs from earlier in the sequence\n", "means computations within an example cannot be parallelized.\n", "\n", "So RNNs must batch across examples to scale,\n", "but as sequence length grows this hits memory bandwidth limits.\n", "Serving up large batches quickly with good randomness guarantees\n", "is also hard to optimize,\n", "especially in distributed settings.\n", "\n", "The Transformer architecture,\n", "on the other hand,\n", "can be readily parallelized within a single example sequence,\n", "in addition to parallelization across batches.\n", "This can lead to massive performance gains for a fixed scale,\n", "which means larger, higher capacity models\n", "can be trained on larger datasets."
] }, { "cell_type": "markdown", "metadata": { "id": "_Mzk2haFC_G1" }, "source": [ "How does the architecture achieve this parallelizability?\n", "\n", "Let's start with the architecture diagram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u59eu4snLQfp" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n", "\n", "display.Image(url=base_url + \"/aiayn-figure-1.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ez-XEQ7M0UlR" }, "source": [ "> To head off a bit of confusion\n", " in case you've worked with Transformer architectures before:\n", " the original \"Transformer\" is an encoder/decoder architecture.\n", " Many LLMs, like GPT models, are decoder only,\n", " because this has turned out to scale well,\n", " and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n", " it's all text anyways.\n", " We, however, will be using them across modalities,\n", " so we need an explicit encoder,\n", " as above. " ] }, { "cell_type": "markdown", "metadata": { "id": "ok4ksBi4vp89" }, "source": [ "First focusing on the encoder (left):\n", "the encoding at a given position is a function of all previous inputs.\n", "But it is not a function of the previous _encodings_:\n", "we produce the encodings \"all at once\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "RPN7C-_OqzHP" }, "source": [ "The decoder (right) does use previous \"outputs\" as its inputs,\n", "but those outputs are not the vectors of layer activations\n", "(aka embeddings)\n", "that are produced by the network.\n", "They are instead the processed outputs,\n", "after a `softmax` and an `argmax`.\n", "\n", "We could obtain these outputs by processing the embeddings,\n", "much like in a recurrent architecture.\n", "In fact, that is one way that Transformers are run.\n", "It's what happens in the `.forward` method\n", "of the model we'll be training for character recognition:\n", "`ResnetTransformer`." ] }, { "cell_type": "markdown", "metadata": { "id": "L5_2WMmtDnJn" }, "source": [ "Let's look at that forward method\n", "and connect it to the diagram." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FR5pk4kEyCGg" }, "outputs": [], "source": [ "from text_recognizer.models import ResnetTransformer\n", "\n", "\n", "ResnetTransformer.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "-J5UFDoPzPbq" }, "source": [ "`.encode` happens first -- that's the left side of the diagram.\n", "\n", "The encoder can in principle be anything\n", "that produces a sequence of fixed-length vectors,\n", "but here it's\n", "[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n", "\n", "Then we start iterating over the sequence\n", "in the `for` loop.\n", "\n", "Focus on the first few lines of code.\n", "We apply `.decode` (right side of the diagram)\n", "to the outputs so far.\n", "\n", "Once we have a new `output`, we apply `.argmax`\n", "to turn the logits into a concrete prediction of\n", "a particular token.\n", "\n", "This is added as the last output token\n", "and then the loop happens again."
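] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loop described above can be sketched in a few lines.\n", "This is a toy stand-in, not the code of `ResnetTransformer`:\n", "`toy_decode`, `start_token`, `max_length`, and the random `embedding` are all hypothetical,\n", "and a real decoder would attend over the encoder outputs rather than simply add them." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"num_classes, max_length, start_token = 10, 6, 0  # toy values, not the lab's\n",
"\n",
"torch.manual_seed(0)\n",
"embedding = torch.randn(num_classes, num_classes)  # stand-in for a trained decoder\n",
"\n",
"\n",
"def toy_decode(encoded: torch.Tensor, tokens_so_far: torch.Tensor) -> torch.Tensor:\n",
"    # returns one row of logits over classes for every position seen so far\n",
"    return embedding[tokens_so_far] + encoded\n",
"\n",
"\n",
"encoded = torch.randn(num_classes)  # stand-in for the output of an encoder\n",
"output_tokens = torch.full((max_length,), start_token, dtype=torch.long)\n",
"\n",
"for position in range(1, max_length):\n",
"    logits = toy_decode(encoded, output_tokens[:position])  # decode the outputs so far\n",
"    output_tokens[position] = torch.argmax(logits[-1])  # keep only the newest prediction\n",
"\n",
"output_tokens"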
] }, { "cell_type": "markdown", "metadata": { "id": "LTcy8-rV1dHr" }, "source": [ "Run this way, our model looks very much like a recurrent architecture:\n", "we call the model on its own outputs\n", "to generate the next value.\n", "These types of models are also referred to as\n", "[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n", "because we predict (as we do in _regression_)\n", "the next value based on our own (_auto_) output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But Transformers are designed to be _trained_ more scalably than RNNs,\n", "not necessarily to _run inference_ more scalably,\n", "and it's actually not the case that our model's `.forward` is called during training." ] }, { "cell_type": "markdown", "metadata": { "id": "eCxMSAWmEKBt" }, "source": [ "Let's look at what happens during training\n", "by checking the `training_step`\n", "of the `LightningModule`\n", "we use to train our Transformer models,\n", "the `TransformerLitModel`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0o7q8N7P2w4H" }, "outputs": [], "source": [ "from text_recognizer.lit_models import TransformerLitModel\n", "\n", "TransformerLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "1VgNNOjvzC4y" }, "source": [ "Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`." ] }, { "cell_type": "markdown", "metadata": { "id": "tz-6NGPR4dUr" }, "source": [ "Let's look at `.teacher_forward`,\n", "and in particular its type signature:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ILc2oWET4i2Z" }, "outputs": [], "source": [ "TransformerLitModel.teacher_forward??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`." 
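] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One common way to set up such a parallel, target-aware pass is to feed in the targets shifted by one position\n", "and to mask out future positions, so that no output can see the token it is supposed to predict.\n", "The sketch below shows only that bookkeeping, on a made-up target sequence;\n", "the token values are invented, and the lab's `teacher_forward` may differ in its details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"y = torch.tensor([[1, 5, 7, 9, 2]])  # toy target: start token 1, three characters, end token 2\n",
"\n",
"decoder_inputs = y[:, :-1]  # everything but the last token goes in\n",
"decoder_targets = y[:, 1:]  # everything but the first token should come out\n",
"\n",
"seq_len = decoder_inputs.shape[1]\n",
"# True marks a blocked position: row i may look only at columns <= i\n",
"causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()\n",
"\n",
"decoder_inputs, decoder_targets, causal_mask"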
] }, { "cell_type": "markdown", "metadata": { "id": "lf32lpgrDb__" }, "source": [ "This is known as \"teacher forcing\".\n", "The \"teacher\" signal is \"forcing\"\n", "the model to behave as though\n", "it got the answer right.\n", "\n", "[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n", "It's more effective here\n", "because the right teaching signal\n", "for our network is the target data,\n", "which we have access to during training,\n", "whereas in an RNN the best teaching signal\n", "would be the target embedding vector,\n", "which we do not know.\n", "\n", "During inference, when we don't have access to the ground truth,\n", "we revert to the autoregressive `.forward` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This \"trick\" allows Transformer architectures to readily scale\n", "up models to the parameter counts\n", "[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)." ] }, { "cell_type": "markdown", "metadata": { "id": "BAjqpJm9uUuU" }, "source": [ "## Is there more to Transformers than just a training trick?" ] }, { "cell_type": "markdown", "metadata": { "id": "kWCYXeHv7Qc9" }, "source": [ "[Very](https://arxiv.org/abs/2005.14165),\n", "[very](https://arxiv.org/abs/1909.08053),\n", "[very](https://arxiv.org/abs/2205.01068)\n", "large Transformer models have powered the most recent wave of exciting results in ML, like\n", "[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n", "\n", "They are also the first machine learning models to have come anywhere close to\n", "deserving the term _artificial intelligence_ --\n", "a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is surprising because the models and their training procedure are\n", "(relatively speaking)\n", "pretty _simple_,\n", "even if it doesn't feel that way on first pass." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic Transformer architecture is just a bunch of\n", "dense matrix multiplications and non-linearities --\n", "it's perhaps simpler than a convolutional architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And advances since the introduction of Transformers in 2017\n", "have not in the main been made by\n", "creating more sophisticated model architectures\n", "but by increasing the scale of the base architecture,\n", "or if anything making it simpler, as in\n", "[GPT-type models](https://arxiv.org/abs/2005.14165),\n", "which drop the encoder." ] }, { "cell_type": "markdown", "metadata": { "id": "V1HQS9ey8GMc" }, "source": [ "These models are also trained on very simple tasks:\n", "most LLMs are just trying to predict the next element in the sequence,\n", "given the previous elements --\n", "a task simple enough that Claude Shannon,\n", "father of information theory, was\n", "[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n", "\n", "These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n", "e.g. by scraping the web." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They are also trained in a simple fashion:\n", "first-order stochastic optimizers, like SGD or an\n", "[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n", "intended for the most basic of optimization problems,\n", "that scale more readily than the second-order optimizers\n", "that dominate other areas of optimization." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kz9HPDoy7OAl" }, "source": [ "This is\n", "[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n", "of work in ML:\n", "simple, even seemingly wasteful,\n", "architectures that scale well and are robust\n", "to implementation details\n", "eventually outstrip more clever but\n", "also more finicky approaches that are harder to scale.\n", "This lesson has led some to declare that\n", "[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n", "in machine learning, and perhaps even in artificial intelligence." ] }, { "cell_type": "markdown", "metadata": { "id": "SdN9o2Y771YZ" }, "source": [ "> That is not to say that because the algorithms are relatively simple,\n", " training a model at this scale is _easy_ --\n", " [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n", " [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n", " [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n", " But choosing the simplest algorithm at every step makes solving the scaling problem feasible." ] }, { "cell_type": "markdown", "metadata": { "id": "baVGf6gKFOvs" }, "source": [ "The importance of scale is the key lesson from the Transformer architecture,\n", "far more than any theoretical considerations\n", "or any of the implementation details.\n", "\n", "That said, these large Transformer models are capable of\n", "impressive behaviors and understanding how they achieve them\n", "is of intellectual interest.\n", "Furthermore, like any architecture,\n", "there are common failure modes,\n", "of the model and of the modelers who use them,\n", "that need to be taken into account." 
] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "Below, we'll cover two key intuitions about Transformers:\n", "Transformers are _residual_, like ResNets,\n", "and they compose _low rank_ sequence transformations.\n", "Together, this means they act somewhat like a computer,\n", "reading from and writing to a \"tape\" or memory\n", "with a sequence of simple instructions." ] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "We'll also cover a surprising implementation detail:\n", "despite being commonly used for sequence modeling,\n", "by default the architecture is _position insensitive_." ] }, { "cell_type": "markdown", "metadata": { "id": "uni0VTCr9lev" }, "source": [ "### Intuition #1: Transformers are highly residual." ] }, { "cell_type": "markdown", "metadata": { "id": "0MoBt-JLJz-d" }, "source": [ "> Our treatment of these intuitions summarizes the discussion in\n", "[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n", "from\n", "[Anthropic](https://www.anthropic.com/),\n", "an AI safety and research company.\n", "The figures below are from that blog post.\n", "It is the spiritual successor to the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "covered in\n", "[Lab 02b](https://fsdl.me/lab02b-colab).\n", "If you want to truly understand Transformers,\n", "we highly recommend you check it out,\n", "including the\n", "[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
] }, { "cell_type": "markdown", "metadata": { "id": "UUbNVvM5Ferm" }, "source": [ "It's easy to see that ResNets are residual --\n", "it's in the name, after all.\n", "\n", "But Transformers are,\n", "in some sense,\n", "even more closely tied to residual computation\n", "than are ResNets:\n", "ResNets and related architectures include downsampling,\n", "so there is not a direct path from inputs to outputs.\n", "\n", "In Transformers, the exact same shape is maintained\n", "from the moment tokens are embedded,\n", "through dozens or hundreds of intermediate layers,\n", "until they are \"unembedded\" into class logits.\n", "The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n", "\n", "The residual stream is easy to see with a change of perspective.\n", "Instead of the usual architecture diagram above,\n", "which emphasizes the layers acting on the tensors,\n", "consider this alternative view,\n", "which emphasizes the tensors as they pass through the layers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HRMlVguKKW6y" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-residual-view.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9K3N7ilVkB3" }, "source": [ "For definitions of variables and terms, see the\n", "[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)." ] }, { "cell_type": "markdown", "metadata": { "id": "arvciE-kKd_L" }, "source": [ "Note that this is a _decoder-only_ Transformer architecture --\n", "so it should be compared with the right-hand side of the original architecture diagram above."
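] }, { "cell_type": "markdown", "metadata": {}, "source": [
"As a rough sketch of the residual view in the figure above,\n",
"the loop below keeps a single tensor, the stream,\n",
"and lets every block add its output to it.\n",
"The `Linear` blocks are hypothetical stand-ins for attention and MLP layers,\n",
"and all sizes are arbitrary:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"vocab_size, d_model, seq_len, n_layers = 10, 16, 5, 3\n",
"\n",
"embed = torch.nn.Embedding(vocab_size, d_model)\n",
"unembed = torch.nn.Linear(d_model, vocab_size, bias=False)\n",
"# stand-ins for attention/MLP blocks: any map that preserves shape will do here\n",
"blocks = [torch.nn.Linear(d_model, d_model) for _ in range(n_layers)]\n",
"\n",
"tokens = torch.randint(0, vocab_size, (seq_len,))\n",
"stream = embed(tokens)  # (seq_len, d_model): the residual stream\n",
"shapes = [stream.shape]\n",
"for block in blocks:\n",
"    stream = stream + block(stream)  # each block's output is added to its input\n",
"    shapes.append(stream.shape)\n",
"\n",
"logits = unembed(stream)  # (seq_len, vocab_size)\n",
"```\n",
"\n",
"Every entry in `shapes` is `(seq_len, d_model)`:\n",
"nothing between the embedding and the unembedding changes the shape of the stream."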
] }, { "cell_type": "markdown", "metadata": { "id": "wvrRMd_RKp_G" }, "source": [ "Notice that the outputs of the attention blocks\n", "and of the MLP layers are\n", "added to their inputs, as in a ResNet.\n", "These operations are represented as \"Add & Norm\" layers in the classical diagram;\n", "normalization is ignored here for simplicity." ] }, { "cell_type": "markdown", "metadata": { "id": "o8n_iT-FFAbK" }, "source": [ "This total commitment to residual operations\n", "means the size of the embeddings\n", "(referred to as the \"model dimension\" or the \"embedding dimension\",\n", "here and below `d_model`)\n", "stays the same throughout the entire network.\n", "\n", "That means, for example,\n", "that the output of each layer can be used as input to the \"unembedding\" layer\n", "that produces logits.\n", "We can read out the computations of intermediate layers\n", "just by passing them through the unembedding layer\n", "and examining the logit tensor.\n", "See\n", "[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n", "for detailed experiments and interactive notebooks.\n", "\n", "In short, we observe a sort of \"progressive refinement\"\n", "of the next-token prediction\n", "as the embeddings proceed, depthwise, through the network." ] }, { "cell_type": "markdown", "metadata": { "id": "Ovh_3YgY9z2h" }, "source": [ "### Intuition #2: Transformer heads learn low rank transformations."
] }, { "cell_type": "markdown", "metadata": { "id": "XpNmozlnOdPC" }, "source": [ "In the original paper and in\n", "most presentations of Transformers,\n", "the attention layer is written like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PA7me8gNP5LE" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In pseudo-typed PyTorch (based loosely on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n", "that looks like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Oeict_6wGJgD" }, "source": [ "```python\n", "def classic_attention(\n", " Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " K: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " return torch.softmax(Q @ K.T) @ V\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "8pewU90DSuOR" }, "source": [ "This is effectively exactly\n", "how it is written\n", "in PyTorch,\n", "apart from implementation details\n", "(look for `bmm` for the matrix multiplications and a `softmax` call):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrgTpKFvOhwc" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "F._scaled_dot_product_attention??" 
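] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The pseudo-typed function above leaves out two details that a runnable version needs:\n",
"the axis over which `softmax` normalizes,\n",
"and the division by the square root of the key dimension used in the original paper.\n",
"A minimal sketch, with the function name and tensor sizes invented for illustration:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"def runnable_classic_attention(Q, K, V):\n",
"    # Q, K, V: (d_sequence, d) tensors\n",
"    scores = Q @ K.T / math.sqrt(Q.shape[-1])  # (d_sequence, d_sequence)\n",
"    weights = torch.softmax(scores, dim=-1)  # each row sums to 1\n",
"    return weights @ V\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"Q, K, V = (torch.randn(4, 8) for _ in range(3))\n",
"out = runnable_classic_attention(Q, K, V)  # (4, 8)\n",
"```\n",
"\n",
"Because each row of the weights sums to one,\n",
"every output row is a weighted average of the rows of `V`."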
] }, { "cell_type": "markdown", "metadata": { "id": "ebDXZ0tlSe7g" }, "source": [ "But the best way to write an operation so that a computer can execute it quickly\n", "is not necessarily the best way to write it so that a human can understand it --\n", "otherwise we'd all be coding in assembly.\n", "\n", "And this is a strange way to write it --\n", "you'll notice that what we normally think of\n", "as the \"inputs\" to the layer are not shown.\n", "\n", "We can instead write out the attention layer\n", "as a function of the inputs $x$,\n", "a matrix with one row per position.\n", "We write it for a single \"attention head\".\n", "Each attention layer includes a number of heads\n", "that read from and write to the residual stream\n", "simultaneously and independently.\n", "We also add the output layer weights $W_O$\n", "and we get:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LuFNR67tQpsf" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(\\underbrace{x W_Q^T}_Q \\underbrace{W_K x^T}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SVnBjjfOLwxP" }, "source": [ "or, in pseudo-typed PyTorch:" ] }, { "cell_type": "markdown", "metadata": { "id": "LmpOm-HfGaNz" }, "source": [ "```python\n", "def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n", " key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n", " key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n", " # maps queries of residual stream to keys from residual stream, independent of position\n", "\n", " value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n", " output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n", " value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n", " # transformation applied to each token, regardless of position\n", "\n", " attention_logits = 
x @ key_query_circuit @ x.T\n", " attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n", " # maps positions to positions, often very sparse\n", "\n", " value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n", "\n", " return attention_map @ value_output # transformed tokens filtered by attention map\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "dC0eqxZ6UAGT" }, "source": [ "Consider the `key_query_circuit`\n", "and `value_output_circuit`\n", "matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n", "\n", "The key/query dimension, `d_head`,\n", "is small relative to the model's dimension, `d_model`,\n", "so $W_{QK}$ and $W_{OV}$ are very low rank,\n", "[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n", "that they factorize into two matrices,\n", "one with a smaller number of rows\n", "and another with a smaller number of columns.\n", "That number is called the _rank_.\n", "\n", "When computing, these matrices are better represented via their components,\n", "rather than computed directly,\n", "which leads to the normal implementation of attention.\n", "\n", "In a large language model,\n", "the ratio of residual stream dimension, `d_model`, to\n", "the dimension of a single head, `d_head`, is huge, often 100:1.\n", "That means each query, key, and value computed at a position\n", "is a fairly simple, low-dimensional feature of the residual stream at that position.\n", "\n", "For visual intuition,\n", "we compare a matrix whose rank is one hundredth of full rank\n", "with a full rank matrix of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LUbojJMiW2C" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "\n", "low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n", "full_rank = torch.randn(100, 100)\n", "plt.figure(); 
plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n", "plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");" ] }, { "cell_type": "markdown", "metadata": { "id": "lqBst92-OVka" }, "source": [ "The pattern in the first matrix is very simple,\n", "relative to the pattern in the second matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "SkCGrs9EiVh4" }, "source": [ "Another feature of low rank transformations is\n", "that they have a large nullspace or kernel --\n", "these are directions we can move the input without changing the output.\n", "\n", "That means that many changes to the residual stream won't affect the behavior of this head at all." ] }, { "cell_type": "markdown", "metadata": { "id": "UVz2dQgzhD4p" }, "source": [ "### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)." ] }, { "cell_type": "markdown", "metadata": { "id": "hVlzwR03m8mC" }, "source": [ "The combination of residuality\n", "(changes are added to the current input)\n", "and low rank\n", "(only a small subspace is changed by each head)\n", "drastically changes the intuition about Transformers." ] }, { "cell_type": "markdown", "metadata": { "id": "qqjZI2jKe6HH" }, "source": [ "Rather than being an \"embedding of a token in its context\",\n", "the residual stream becomes something more like a memory or a scratchpad:\n", "one layer reads a small bit of information from the stream\n", "and writes a small bit of information back to it." 
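] }, { "cell_type": "markdown", "metadata": {}, "source": [
"To make the small-subspace claim concrete,\n",
"the sketch below builds a value/output circuit from random weights,\n",
"checks its rank,\n",
"and then moves the input along a direction that the head cannot see.\n",
"The sizes and the random weights are illustrative assumptions,\n",
"not values from a trained model:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"d_model, d_head = 64, 4\n",
"W_V = torch.randn(d_head, d_model, dtype=torch.float64)\n",
"W_O = torch.randn(d_model, d_head, dtype=torch.float64)\n",
"value_output_circuit = W_V.T @ W_O.T  # (d_model, d_model), rank at most d_head\n",
"\n",
"rank = torch.linalg.matrix_rank(value_output_circuit).item()\n",
"\n",
"# rows of Vh past the first d_head span the directions that W_V maps to zero\n",
"_, _, Vh = torch.linalg.svd(W_V, full_matrices=True)\n",
"invisible = Vh[d_head:]  # (d_model - d_head, d_model)\n",
"\n",
"x = torch.randn(d_model, dtype=torch.float64)\n",
"x_moved = x + 10.0 * invisible[0]  # a large step in an invisible direction\n",
"unchanged = torch.allclose(x @ value_output_circuit, x_moved @ value_output_circuit)\n",
"```\n",
"\n",
"Here `d_model - d_head` of the `d_model` input directions are invisible to the head,\n",
"so `unchanged` holds even though `x_moved` is far from `x`."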
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YIBkxlqepjc" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-layer-residual.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtsKhkLfk00l" }, "source": [ "The residual stream works like a memory because it is roomy enough\n", "that these actions need not interfere:\n", "the subspaces targeted by reads and writes are small relative to the ambient space,\n", "so they can avoid overlapping.\n", "\n", "Additionally, the dimension of each head is still in the 100s in large models,\n", "and\n", "[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n", "in them, so the number of effective degrees of freedom is\n", "actually larger than the dimension.\n", "This phenomenon allows high-dimensional tensors to serve as\n", "[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n", "There are\n", "[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n", "\n", "Together, this means an early layer can write information to the stream\n", "that can be used by later layers -- by many of them at once, possibly much later.\n", "Later layers can learn to edit this information,\n", "e.g. deleting it,\n", "if doing so reduces the loss,\n", "but by default the information is preserved." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EragIygzJg86" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-stream-read-write.png\")"
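] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The almost-orthogonal claim above can be checked numerically.\n",
"The sketch below draws more random unit vectors than there are dimensions\n",
"and measures how much any two of them overlap;\n",
"the dimension and the number of vectors are arbitrary choices for illustration:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"d, n_vectors = 512, 1024  # twice as many vectors as dimensions\n",
"\n",
"vectors = torch.nn.functional.normalize(torch.randn(n_vectors, d, dtype=torch.float64), dim=1)\n",
"cosines = vectors @ vectors.T  # pairwise cosine similarities\n",
"off_diagonal = cosines - torch.eye(n_vectors, dtype=torch.float64)  # zero out each vector's overlap with itself\n",
"\n",
"largest_overlap = off_diagonal.abs().max().item()\n",
"typical_overlap = off_diagonal.abs().mean().item()\n",
"```\n",
"\n",
"With these settings the typical overlap should come out on the order of `1 / sqrt(d)`,\n",
"far from the value of 1 that parallel vectors would give."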
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dN6VcJqIMKnB" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-token-to-token.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Repeatedly reading information from an external memory\n", "and using it to decide which operation to perform\n", "and where to write the results\n", "is at the core of the\n", "[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n", "For a concrete example, the\n", "[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n", "includes a dissection of a form of \"pointer arithmetic\"\n", "that appears in some models." ] }, { "cell_type": "markdown", "metadata": { "id": "0kLFh7Mvnolr" }, "source": [ "This point of view seems\n", "very promising for explaining numerous\n", "otherwise perhaps counterintuitive features of Transformer models.\n", "\n", "- This framework predicts that Transformers will readily copy-and-paste information,\n", "which might explain phenomena like\n", "[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n", "\n", "- It also readily explains\n", "[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n", "an important component of why Transformers perform well on medium-length texts\n", "and in few-shot learning.\n", "\n", "- Transformers also perform better on reasoning tasks when the text\n", "[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n", "is added to their input prompt.\n", "This is partly because that prompt is associated,\n", "in the dataset, with clearer reasoning,\n", "and since the models are trained to predict which tokens tend to appear\n", "after an input, they tend to produce better reasoning with that prompt --\n", "an explanation purely in terms of sequence modeling.\n", 
"But it also gives the Transformer license to generate a large number of tokens\n", "that act to store intermediate information,\n", "making for a richer residual stream\n", "for reading and writing." ] }, { "cell_type": "markdown", "metadata": { "id": "RyLRzgG-93yB" }, "source": [ "### Implementation detail: Transformers are position-insensitive by default." ] }, { "cell_type": "markdown", "metadata": { "id": "oR6PnrlA_hJ2" }, "source": [ "In the attention calculation\n", "each token can query each other token,\n", "with no regard for order.\n", "Furthermore, the construction of queries, keys, and values\n", "is based on the content of the embedding vector,\n", "which does not automatically include its position.\n", "\"dog bites man\" and \"man bites dog\" are identical, as in\n", "[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n", "\n", "For most sequences,\n", "this is unacceptable:\n", "absolute and relative position matter\n", "and we cannot use the future to predict the past.\n", "\n", "We need to add two pieces to get a Transformer architecture that's usable for next-token prediction." ] }, { "cell_type": "markdown", "metadata": { "id": "EWHxGJz2-6ZK" }, "source": [ "First, the simpler piece:\n", "\"causal\" attention,\n", "so-named because it ensures that values earlier in the sequence\n", "are not influenced by later values, which would\n", "[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)." 
] }, { "cell_type": "markdown", "metadata": { "id": "0c42xi6URYB4" }, "source": [ "The most common solution is straightforward:\n", "we calculate attention between all tokens,\n", "then throw out non-causal values by \"masking\" them\n", "(this is before applying the softmax,\n", "so masking means adding $-\\infty$).\n", "\n", "This feels wasteful --\n", "why are we calculating values we don't need?\n", "Trying to be smarter would be harder,\n", "and might rely on operations that aren't as optimized as\n", "matrix multiplication and addition.\n", "Furthermore, it's \"only\" twice as many operations,\n", "so it doesn't even show up in $O$-notation.\n", "\n", "A sample attention mask generated by our code base is shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NXaWe6pT-9jV" }, "outputs": [], "source": [ "from text_recognizer.models import transformer_util\n", "\n", "\n", "attention_mask = transformer_util.generate_square_subsequent_mask(100)\n", "\n", "ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n", "plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n", "print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This solves our causality problem,\n", "but we still don't have positional information." 
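] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The missing positional information can be seen directly.\n",
"In the sketch below, a single unmasked attention head with random weights\n",
"(an illustrative stand-in, not a module from our codebase)\n",
"is applied to a sequence and to a shuffled copy of it:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"d_sequence, d_model = 5, 8\n",
"W_Q, W_K, W_V = (torch.randn(d_model, d_model, dtype=torch.float64) for _ in range(3))\n",
"\n",
"\n",
"def unmasked_attention(x):\n",
"    Q, K, V = x @ W_Q.T, x @ W_K.T, x @ W_V.T\n",
"    return torch.softmax(Q @ K.T / d_model ** 0.5, dim=-1) @ V\n",
"\n",
"\n",
"x = torch.randn(d_sequence, d_model, dtype=torch.float64)\n",
"perm = torch.tensor([2, 0, 4, 1, 3])  # an arbitrary reordering of the five positions\n",
"\n",
"# shuffling the inputs and then attending matches attending and then shuffling the outputs\n",
"same_up_to_order = torch.allclose(unmasked_attention(x)[perm], unmasked_attention(x[perm]))\n",
"```\n",
"\n",
"Shuffling the tokens only shuffles the outputs in the same way,\n",
"so nothing in this computation depends on where a token sits in the sequence."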
] }, { "cell_type": "markdown", "metadata": { "id": "ZamUE4WIoGS2" }, "source": [ "The standard technique\n", "is to add alternating sines and cosines\n", "of increasing frequency to the embeddings\n", "(there are\n", "[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n", "most notably\n", "[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n", "Each position in the sequence is then uniquely identifiable\n", "from the pattern of these values.\n", "\n", "> Furthermore, for the same reason that\n", " [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n", " translations, e.g. relative positions, are fairly easy to express as linear transformations\n", " of sines and cosines." ] }, { "cell_type": "markdown", "metadata": { "id": "IDG2uOsaELU0" }, "source": [ "We superimpose this positional information on our embeddings.\n", "Note that because the model is residual,\n", "this position information will by default be preserved\n", "as it passes through the network,\n", "so it doesn't need to be repeatedly added."
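] }, { "cell_type": "markdown", "metadata": {}, "source": [
"A minimal sketch of the sinusoidal scheme,\n",
"following the formulation in the original Transformer paper.\n",
"The function name is made up for illustration,\n",
"and the `PositionalEncoding` module in our codebase may differ in details\n",
"such as batching and dropout:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"def sinusoidal_encoding(max_len, d_model):\n",
"    # assumes d_model is even, so sines and cosines pair up\n",
"    position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)\n",
"    # one frequency per sine/cosine pair, spaced geometrically\n",
"    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
"    encoding = torch.zeros(max_len, d_model)\n",
"    encoding[:, 0::2] = torch.sin(position * div_term)  # even dimensions\n",
"    encoding[:, 1::2] = torch.cos(position * div_term)  # odd dimensions\n",
"    return encoding\n",
"\n",
"\n",
"encoding = sinusoidal_encoding(max_len=200, d_model=50)  # (200, 50)\n",
"```\n",
"\n",
"Each row is the pattern added to the embedding at one position,\n",
"and no two rows are the same, which is what makes positions identifiable."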
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what this positional encoding looks like in our codebase:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Zk62Q-a-1Ax" }, "outputs": [], "source": [ "PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n", "\n", "pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n", "\n", "ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n", "print(pe[:4, :8])" ] }, { "cell_type": "markdown", "metadata": { "id": "ep2ClIWvqDms" }, "source": [ "When we add the positional information to our embeddings,\n", "both the embedding information and the positional information\n", "are approximately preserved,\n", "as can be visually assessed below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PJuFjoCzC0Y4" }, "outputs": [], "source": [ "fake_embeddings = torch.randn_like(pe) * 0.5\n", "\n", "ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n", "\n", "fake_embeddings_with_pe = fake_embeddings + pe\n", "\n", "plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIzBxDkEmH8" }, "source": [ "A [similar technique](https://arxiv.org/abs/2103.06450)\n", "is used to also incorporate positional information into the image embeddings,\n", "which are flattened before being fed to the decoder."
] }, { "cell_type": "markdown", "metadata": { "id": "HC1N85wl8dvn" }, "source": [ "### Learn more about Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "lJwYxkjTk15t" }, "source": [ "We're only able to give a flavor and an intuition for Transformers here.\n", "\n", "To improve your grasp on the nuts and bolts, check out the\n", "[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n", "which is surprisingly approachable,\n", "as far as ML research papers go.\n", "The\n", "[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n", "adds code and commentary to the original paper,\n", "which makes it even more digestible.\n", "For something even friendlier, check out the\n", "[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n", "by Jay Alammar, which has an accompanying\n", "[video](https://youtu.be/-QH8fRhqFHM).\n", "\n", "Anthropic's work on\n", "[Transformer Circuits](https://transformer-circuits.pub/),\n", "summarized above, has some of the best material\n", "for building theoretical understanding\n", "and is still being updated with extensions and applications of the framework.\n", "The\n", "[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n", "are a great aid for checking and building your understanding.\n", "\n", "But they are fairly math-heavy.\n", "If you have more of a software engineering background, see\n", "Transformer Circuits co-author Nelson Elhage's blog post\n", "[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n", "\n", "For a gentler introduction to the intuition for Transformers,\n", "check out Brandon Rohrer's\n", "[Transformers From Scratch](https://e2eml.school/transformers.html)\n", "tutorial." 
] }, { "cell_type": "markdown", "metadata": { "id": "qg7zntJES-aT" }, "source": [ "An aside:\n", "the matrix multiplications inside attention dominate\n", "the big-$O$ runtime of Transformers.\n", "So trying to make the attention mechanism more efficient, e.g. linear time,\n", "has generated a lot of research\n", "(review paper\n", "[here](https://arxiv.org/abs/2009.06732)).\n", "Despite drawing a lot of attention, so to speak,\n", "at the time of writing in mid-2022, these methods\n", "[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n", "so it isn't likely to be worth the effort to spend time learning about them\n", "unless you are a Transformer specialist." ] }, { "cell_type": "markdown", "metadata": { "id": "vCjXysEJ8g9_" }, "source": [ "# Using Transformers to read paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "KsfKWnOvqjva" }, "source": [ "Our simple convolutional model for text recognition from\n", "[Lab 02b](https://fsdl.me/lab02b-colab)\n", "could only handle cleanly-separated characters.\n", "\n", "It worked by sliding a LeNet-style CNN\n", "over the image,\n", "predicting a character for each step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njLdzBqy-I90" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist_lines = text_recognizer.data.EMNISTLines()\n", "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "\n", "# for sliding, see the for loop over range(S)\n", "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "K0N6yDBQq8ns" }, "source": [ "But unfortunately for us, handwritten text\n", "doesn't come in neatly-separated characters\n", "of equal size, so we trained our model on synthetic data\n", "designed to work with that model." 
] }, { "cell_type": "markdown", "metadata": { "id": "hiqUVbj0sxLr" }, "source": [ "Now that we have a better model,\n", "we can work with better data:\n", "paragraphs from the\n", "[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)." ] }, { "cell_type": "markdown", "metadata": { "id": "oizsOAcKs-dD" }, "source": [ "The cell uses our `LightningDataModule`\n", "to download and preprocess this data,\n", "writing results to disk.\n", "We can then spin up `DataLoader`s to give us batches.\n", "\n", "It can take several minutes to run the first time\n", "on commodity machines,\n", "with most time spent extracting the data.\n", "On subsequent runs,\n", "the time-consuming operations will not be repeated." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uL9LHbjdsUbm" }, "outputs": [], "source": [ "iam_paragraphs = text_recognizer.data.IAMParagraphs()\n", "\n", "iam_paragraphs.prepare_data()\n", "iam_paragraphs.setup()\n", "xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n", "\n", "iam_paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "nBkFN9bbTm_S" }, "source": [ "Now that we've got a batch,\n", "let's take a look at some samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hqaps8yxtBhU" }, "outputs": [], "source": [ "import random\n", "\n", "import numpy as np\n", "import wandb\n", "\n", "\n", "def show(y):\n", " y = y.detach().cpu() # bring back from accelerator if it's being used\n", " return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"

\", \"\")\n", "\n", "idx = random.randint(0, len(xs) - 1) # randint includes both endpoints\n", "\n", "print(show(ys[idx]))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "4dT3UCNzTsoc" }, "source": [ "The `ResnetTransformer` model can run on this data\n", "if passed the data module's `.config()`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WXL-vIGRr86D" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())" ] }, { "cell_type": "markdown", "metadata": { "id": "MMxa-oWyT01E" }, "source": [ "Our models are now big enough\n", "that we want to make use of GPU acceleration\n", "as much as we can,\n", "even when working on single inputs,\n", "so let's move the model and data to the GPU if we have one." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YyUM8LgvW0w" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "rnt.to(device); xs = xs.to(device); ys = ys.to(device);" ] }, { "cell_type": "markdown", "metadata": { "id": "Y-E3UdD4zUJi" }, "source": [ "First, let's just pass an input through the ResNet encoder."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-LUUtlvaxrvg" }, "outputs": [], "source": [ "resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n", " # resnet is designed for RGB images, so we replicate the input across channels 3 times" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eimgJ5dnywjg" }, "outputs": [], "source": [ "resnet_idx = random.randint(0, len(resnet_embedding) - 1) # re-execute to view a different channel\n", "plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n", "plt.axis(\"off\"); plt.colorbar(fraction=0.05);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These embeddings, though generated by random, untrained weights,\n", "are not entirely useless.\n", "\n", "Before neural networks could be effectively\n", "trained end to end,\n", "they were often used with frozen random weights\n", "everywhere except the final layer\n", "(see e.g.\n", "[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n", "[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n", "these methods were still competitive, and\n", "[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n", "provide a\n", "[theoretical basis](https://arxiv.org/abs/2011.14522)\n", "for understanding their performance." ] }, { "cell_type": "markdown", "metadata": { "id": "ye6pW0ETzw2A" }, "source": [ "The final result, though, is repetitive gibberish --\n", "at the bare minimum, we need to train the unembedding/readout layer\n", "in order to get reasonable text." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our architecture includes randomization with dropout,\n", "so repeated runs of the cell below will generate different outcomes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu3Pa7gLsFMo" }, "outputs": [], "source": [ "preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gvCXUbskv6XM" }, "outputs": [], "source": [ "print(show(preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without teacher forcing, runtime is also variable from iteration to iteration --\n", "the model stops when it generates an \"end sequence\" or padding token,\n", "which is not deterministic thanks to the dropout layers.\n", "For similar reasons, runtime is variable across inputs.\n", "\n", "The variable runtime of autoregressive generation\n", "is also not great for scaling.\n", "In a distributed setting, as required for large scale,\n", "forward passes need to be synced across devices,\n", "and if one device is generating a batch of much longer sequences,\n", "it will cause all the others to idle while they wait on it to finish." ] }, { "cell_type": "markdown", "metadata": { "id": "t76MSVRXV0V7" }, "source": [ "Let's turn our model into a `TransformerLitModel`\n", "so we can run with teacher forcing.\n", "\n", "> You may be wondering:\n", " why isn't teacher forcing part of the PyTorch module?\n", " In general, the `LightningModule`\n", " should encapsulate things that are needed in training, validation, and testing\n", " but not during inference.\n", " The teacher forcing trick fits this paradigm,\n", " even though it's so critical to what makes Transformers powerful. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8qrHRKHowdDi" }, "outputs": [], "source": [ "import text_recognizer.lit_models\n", "\n", "lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlNaFqR50Oid" }, "source": [ "Now we can use `.teacher_forward` if we also provide the target `ys`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lpZdqXS5wn0F" }, "outputs": [], "source": [ "forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Zx9SmsN0QLT" }, "source": [ "This may not run faster than the `rnt.forward`,\n", "since generations are always the maximum possible length,\n", "but runtimes and output lengths are deterministic and constant." ] }, { "cell_type": "markdown", "metadata": { "id": "tu-XNYpi0Qvi" }, "source": [ "Forcing doesn't necessarily make our predictions better.\n", "They remain highly repetitive gibberish." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JcEgify9w0sv" }, "outputs": [], "source": [ "forcing_preds = torch.argmax(forcing_outs, dim=0)\n", "\n", "print(show(forcing_preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xn6GGNzc9a3o" }, "source": [ "## Training the `ResNetTransformer`" ] }, { "cell_type": "markdown", "metadata": { "id": "uvZYsuSyWUXe" }, "source": [ "We're finally ready to train this model on full paragraphs of handwritten text!" 
] }, { "cell_type": "markdown", "metadata": { "id": "3cJwC7b720Sd" }, "source": [ "This is a more serious model --\n", "it's the one we use in the\n", "[deployed TextRecognizer application](http://fsdl.me/app).\n", "It's much larger than the models we've seen thus far,\n", "so it can easily outstrip available compute resources,\n", "in particular GPU memory.\n", "\n", "To help, we use\n", "[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n", "which halves the size of most of our floats,\n", "reducing memory consumption and often speeding up computation.\n", "\n", "If your GPU has less than 8GB of available RAM,\n", "you'll likely see a \"CUDA out of memory\" `RuntimeError`,\n", "which is something of a\n", "[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n", "In this case, you can resolve it by reducing the `--batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w1mXlhfy04Nm" }, "outputs": [], "source": [ "import torch\n", "\n", "gpus = int(torch.cuda.is_available())\n", "\n", "if gpus:\n", " !nvidia-smi\n", "else:\n", " print(\"watch out! working with this model on a typical CPU is not feasible\")" ] }, { "cell_type": "markdown", "metadata": { "id": "os1vW1rPZ1dy" }, "source": [ "Even with an okay GPU, like a\n", "[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n", "a single epoch of training can take over 10 minutes to run.\n", "We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n", "but you can remove those flags to see what full training looks like." 
] }, { "cell_type": "markdown", "metadata": { "id": "vnF6dWFn4JlZ" }, "source": [ "It can take a long time (overnight)\n", "to train this model to decent performance on a single GPU,\n", "so we'll focus on other pieces for the exercises.\n", "\n", "> At the time of writing in mid-2022, the cheapest readily available option\n", "for training this model to decent performance on this dataset with this codebase\n", "comes out around $10, using\n", "[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n", "See, for example,\n", "[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n", "and associated experiment.\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HufjdUZN0t4l", "scrolled": false }, "outputs": [], "source": [ "%%time\n", "# above %%magic times the cell, useful as a poor man's profiler\n", "\n", "%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n", " --gpus={gpus} --batch_size 16 --precision 16 \\\n", " --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2" ] }, { "cell_type": "markdown", "metadata": { "id": "L6fQ93ju3Iku" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "udb1Ekjx3L63" }, "source": [ "### 🌟 Try out gradient accumulation and other \"training tricks\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "kpqViB4p3Wfb" }, "source": [ "Larger batches are helpful not only for increasing parallelization\n", "and amortizing fixed costs\n", "but also for getting more reliable gradients.\n", "Larger batches give gradients with less noise\n", "and to a point, less gradient noise means faster convergence.\n", "\n", "But larger batches result in larger tensors,\n", "which take up more GPU memory,\n", "a resource that is tightly constrained\n", "and device-dependent.\n", "\n", "Does that mean we are limited in the quality of our gradients\n", "due to our machine size?\n", "\n", "Not entirely:\n", "look up the `--accumulate_grad_batches`\n", "argument to the `pl.Trainer`.\n", "You should be able to understand why\n", "it makes it possible to compute the same gradients\n", "you would find for a batch of size `k * N`\n", "on a machine that can only run batches up to size `N`.\n", "\n", "Accumulating gradients across batches is among the\n", "[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n", "Try some of them out!\n", "Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment." 
] }, { "cell_type": "markdown", "metadata": { "id": "b2vtkmX830y3" }, "source": [ "### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n", "\n", "While training this model to actually fit the whole dataset is infeasible\n", "as a short exercise on commodity hardware,\n", "it's practical to train this model to memorize a batch of 16 examples.\n", "\n", "Passing `--overfit_batches 1` flag limits the number of training batches to 1\n", "and turns off\n", "[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n", "so that in each epoch, the model just sees the same single batch of data over and over again.\n", "\n", "At first, try training the model to a loss of `2.5` --\n", "it should be doable in 100 epochs or less,\n", "which is just a few minutes on a commodity GPU.\n", "\n", "Once you've got that working,\n", "crank up the number of epochs by a factor of 10\n", "and confirm that the loss continues to go down.\n", "\n", "Some tips:\n", "\n", "- Use `--limit_test_batches 0` to turn off testing.\n", "We don't need it because we don't care about generalization\n", "and it's relatively slow because it runs the model autoregressively.\n", "\n", "- Use `--help` and look through the model class args\n", "to find the arguments used to reduce model size.\n", "\n", "- By default, there's lots of regularization to prevent overfitting.\n", "Look through the args for the model class and data class\n", "for regularization knobs to turn off or down." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab03_transformers.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab06/notebooks/lab04_experiments.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 04: Experiment Management" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How experiment management brings observability to ML model development\n", "- Which features of experiment management we use in developing the Text Recognizer\n", "- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 4\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This lab contains a large number of embedded iframes\n", "that benefit from having a wide window.\n", "The cell below makes the notebook as wide as your browser window\n", "if `full_width` is set to `True`.\n", "Full width is the default behavior in Colab,\n", "so this cell is intended to improve the viewing experience in other Jupyter environments." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Why experiment management?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand why we need experiment management for ML development,\n", "let's start by running an experiment.\n", "\n", "We'll train a new model on a new dataset,\n", "using the training script `training/run_experiment.py`\n", "introduced in [Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a CNN encoder and Transformer decoder, as in\n", "[Lab 03](https://fsdl.me/lab03-colab),\n", "but with some changes so we can iterate faster.\n", "We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n", "[Lab 02b](https://fsdl.me/lab02b-colab),\n", "and we'll use a smaller CNN (`--model_class LineCNNTransformer`)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n", "from text_recognizer.data import IAMLines # processed version split into individual lines\n", "from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n", "\n", "\n", "print(IAM.__doc__)\n", "\n", "# uncomment a line below for details on either class\n", "# IAMLines?? \n", "# LineCNNTransformer??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below will train a model on 10% of the data for two epochs.\n", "\n", "It takes up to a few minutes to run on commodity hardware,\n", "including data download and preprocessing.\n", "As it's running, continue reading below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%%time\n", "import torch\n", "\n", "\n", "gpus = int(torch.cuda.is_available()) \n", "\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the model trains, we're calculating lots of metrics --\n", "loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n", "and reporting them to the terminal.\n", "\n", "This is achieved by the built-in `.log` method\n", "([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n", "of the `LightningModule`,\n", "and it is a very straightforward way to get basic information about your experiment as it's running\n", "without leaving the context where you're running it." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Learning to read\n", "[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n", "is something of a rite of passage for MLEs, but\n", "let's consider what we can't see here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We're missing all metric values except the most recent --\n", "we can see them as they stream in, but they're constantly overwritten.\n", "We also can't associate them with timestamps, steps, or epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We also don't see any system metrics.\n", "We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n", "without launching a separate process.\n", "And even if we do, those values will also not be saved and timestamped,\n", "so we can't correlate them with other things during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- As we continue to run experiments, changing code and opening new terminals,\n", "even the information we have or could figure out now will disappear.\n", "Say you spot a weird error message during training,\n", "but your session ends and the stdout is gone,\n", "so you don't know exactly what it was.\n", "Can you recreate the error?\n", "Which git branch and commit were you on?\n", "Did you have any uncommitted changes? Which arguments did you pass?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Also, model checkpoints containing the parameter values have been saved to disk.\n", "Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n", "As we run more and more experiments,\n", "we'll want to slice and dice them to see if,\n", "say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to save and log all of this information, and more, in order to make our model training\n", "[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n", "in short, so that we can understand, make decisions about, and debug our model training\n", "by looking at logs and source code, without having to recreate it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we had to write the logging code we need to save this information ourselves, that'd put us in for a world of hurt:\n", "1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n", "2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n", "3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n", "4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Local Experiment Tracking with Tensorboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n", "and it makes logging very easy." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n", "\n", "It can also use a logger to push information elsewhere.\n", "\n", "By default, we use\n", "[TensorBoard](https://www.tensorflow.org/tensorboard)\n", "via the Lightning `TensorBoardLogger`,\n", "which has been saving results to the local disk.\n", "\n", "Let's find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest experiment's directory\n", "# by hand, you can just copy and paste it from the terminal\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n", "filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n", "latest_log" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!ls -lh {latest_log}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, we need to launch a TensorBoard server --\n", "much like we need to launch a Jupyter server to use Jupyter notebooks.\n", "\n", "The cells below load an extension that lets you use TensorBoard inside of a notebook\n", "the same way you'd use it from the command line, and then launch it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext tensorboard" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n", "\n", "port = 11717 # pick an open port on your machine\n", "host = \"0.0.0.0\" # bind to all network interfaces, allowing connections from other machines\n", " # watch out! make sure you turn TensorBoard off when you're done\n", "\n", "%tensorboard --logdir {latest_log} --port {port} --host {host}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see some charts of metrics over time along with some charting controls.\n", "\n", "You can click around in this interface and explore it if you'd like,\n", "but in the next section, we'll see that there are better tools for experiment management." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've run many experiments on this machine,\n", "you can see all of their results by pointing TensorBoard\n", "at the whole `lightning_logs` directory,\n", "rather than just one experiment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For large numbers of experiments, the management experience is not great --\n", "for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n", "\n", "It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n", "which are important workflows as applications mature and teams grow." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n", "\n", "If you run into any issues with the above cells failing to launch,\n", "especially across iterations of this lab, run the cell below with `done_with_tensorboard` set to `True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorboard.manager\n", "\n", "# get the process IDs for all TensorBoard instances\n", "pids = [tb.pid for tb in tensorboard.manager.get_all()]\n", "\n", "done_with_tensorboard = False\n", "\n", "if done_with_tensorboard:\n", " # kill processes\n", " for pid in pids:\n", " !kill {pid} 2> /dev/null\n", " \n", " # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n", " !rm -rf {tensorboard.manager._get_info_dir()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment Management with Weights & Biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How do we manage experiments when we hit the limits of local TensorBoard?" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is powerful and flexible and very scalable,\n", "but running it requires engineering effort and babysitting --\n", "you're running a database, writing data to it,\n", "and layering a web application over it.\n", "\n", "This is a fairly common workflow for web developers,\n", "but not so much for ML engineers.\n", "\n", "You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n", "and it's as simple as running the command `tensorboard dev upload`\n", "pointed at your logging directory.\n", "\n", "But there are strict limits to this free service:\n", "1GB of tensor data and 1GB of binary data.\n", "A single Text Recognizer model checkpoint is ~100MB,\n", "and that's not particularly large for a useful model.\n", "\n", "Furthermore, all data is public,\n", "so if you upload the inputs and outputs of your model,\n", "anyone who finds the link can see them.\n", "\n", "Overall, tensorboard.dev works very well for certain academic and open projects\n", "but not for industrial ML." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To avoid that narrow permissions and limits issue,\n", "you could use [git LFS](https://git-lfs.github.com/)\n", "to track the binary data and tensor data,\n", "which is more likely to be sensitive than metrics.\n", "\n", "The Hugging Face ecosystem uses TensorBoard and git LFS.\n", "\n", "It includes the Hugging Face Hub, a git server much like GitHub,\n", "but designed first and foremost for collaboration on models and datasets,\n", "rather than collaboration on code.\n", "For example, the Hugging Face Hub\n", "[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n", "and officially has\n", "[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n", "avoiding the\n", "[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n", "that make using git LFS with GitHub expensive.\n", "\n", "However, we prefer to avoid mixing software version control and experiment management.\n", "\n", "First, using the Hub requires maintaining an additional git remote,\n", "which is a hard ask for many engineering teams.\n", "\n", "Secondly, git-style versioning is an awkward fit for logging --\n", "is it really sensible to create a new commit for each logging event while you're watching live?\n", "\n", "Instead, we prefer to use systems that solve experiment management with _databases_." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n", "The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n", "tool is [MLflow](https://github.com/mlflow/mlflow/)\n", "and there are a number of\n", "[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n", "\n", "These tools generally avoid any need to worry about hosting\n", "(unless data governance rules require a self-hosted version).\n", "\n", "For a sampling of publicly-posted opinions on experiment management tools,\n", "see these discussions from Reddit:\n", "\n", "- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n", "- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n", "\n", "Among these tools, the FSDL recommendation is\n", "[Weights & Biases](https://wandb.ai),\n", "which we believe offers\n", "- the best user experience, both in the Python SDKs and in the graphical interface\n", "- the best integrations with other tools,\n", "including\n", "[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n", "[Keras](https://docs.wandb.ai/guides/integrations/keras),\n", "[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n", "and even\n", "[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n", "and\n", "- the best tools for collaboration.\n", "\n", "Below, we'll take care to point out which logging and management features\n", "are available via generic interfaces in Lightning and which are W&B-specific." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "\n", "print(wandb.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding W&B to our experiment-running code is extremely easy,\n", "relative to the features we get, which is\n", "one of the main selling points of W&B.\n", "\n", "We get most of our new experiment management features just by changing a single variable, `logger`, from\n", "`TensorBoardLogger` to `WandbLogger`\n", "and adding two lines of code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see what each of these lines does for us below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this logger is built into and maintained by PyTorch Lightning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning.loggers import WandbLogger\n", "\n", "\n", "WandbLogger??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to complete the rest of this notebook,\n", "you'll need a Weights & Biases account.\n", "\n", "As with GitHub, the free tier for personal, academic, and open source work\n", "is very generous.\n", "\n", "The Text Recognizer project will fit comfortably within the free tier.\n", "\n", "Run the cell below and follow the prompts to log in or create an account, or sign up\n", "[here](https://wandb.ai/signup)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb login" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the cell below to launch an experiment tracked with Weights & Biases.\n", "\n", "The experiment can take between 3 and 10 minutes to run.\n", "In that time, continue reading below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n", " --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1\n", " \n", "last_expt = wandb.run\n", "\n", "wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see some new things in our output.\n", "\n", "For example, there's a note from `wandb` that the data is saved locally\n", "and also synced to their servers.\n", "\n", "There's a link to a webpage for viewing the logged data and a name for our experiment --\n", "something like `dandy-sunset-1`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local logging and cloud syncing happens with minimal impact on performance,\n", "because `wandb` launches a separate process to listen for events and upload them.\n", "\n", "That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Runs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, head to the link in the notebook output\n", "that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n", "\n", "There's no need to wait for training to finish.\n", "\n", "The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(last_expt.url)\n", "IFrame(last_expt.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have landed on the run page\n", "([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n", "which collects up all of the information for a single experiment into a collection of tabs.\n", "\n", "We'll work through these tabs from top to bottom.\n", "\n", "Each header is also a link to the documentation for a tab." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n", "This tab has an icon that looks like `(i)` or 🛈.\n", "\n", "The top section of this tab has high-level information about our run:\n", "- Timing information, like start time and duration\n", "- System hardware, hostname, and basic environment info\n", "- Git repository link and state\n", "\n", "This information is collected and logged automatically.\n", "\n", "The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n", "and summary metrics.\n", "\n", "Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n", "\n", "This tab has a line plot icon, something like 📈.\n", "\n", "It's also the default page you land on when looking at a W&B run.\n", "\n", "Charts are generated for everything we `.log` from PyTorch Lightning. 
The charts here are interactive and editable, and changes persist.\n", "\n", "Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n", "\n", "We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n", "This tab has a computer chip icon.\n", "\n", "It contains\n", "- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n", "- CPU metrics: memory usage, utilization, thread counts\n", "- Disk and network I/O levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n", "This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n", "\n", "The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n", "This tab has an icon that looks like a stylized command prompt, `>_`.\n", "\n", "It contains everything that was printed to stdout.\n", "\n", "This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n", "\n", "Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelSummary\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n", "\n", "For more on Lightning `Callback`s, see\n", "[Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n", "This tab has a stylized document icon, something like 📄.\n", "\n", "You can use this tab to view any files saved with the `wandb.save`.\n", "\n", "For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n", "which we'll discuss shortly.\n", "\n", "But a few pieces of information automatically collected by W&B end up in this tab.\n", "\n", "Some highlights:\n", " - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n", " - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n", "This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n", "\n", "This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n", "\n", "We store two kinds of binary files\n", " - `run_table`s of model inputs and outputs\n", " - `model` checkpoints\n", "\n", "We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive Tables of Logged Media" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returning to the Charts tab,\n", "notice that we have model inputs and outputs logged in structured tables\n", "under the train, validation, and test sections.\n", "\n", "These tables are interactive as well\n", "([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n", "They support basic exploratory data analysis and are compatible with W&B's collaboration features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to charts in our run page, these tables also have their own pages inside the W&B web app." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n", "table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n", "\n", "print(table_data_url)\n", "IFrame(src=table_data_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting this to work requires more effort and more W&B-specific code\n", "than the other features we've seen so far.\n", "\n", "We'll briefly explain the implementation here, for those who are interested.\n", "\n", "We use a custom Lightning `Callback`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n", "\n", "\n", "ImageToTextTableLogger??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n", "\n", "The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n", "\n", "This behavior is sensible for metrics, which are low overhead, but not so much for media,\n", "where we'd rather subsample and avoid holding on to too much information.\n", "\n", "So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n", "\n", "The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.lit_models.base import BaseImageToTextLitModel\n", "\n", "BaseImageToTextLitModel.add_on_logged_batches??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Projects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Everything we've seen so far has been related to a single run or experiment.\n", "\n", "Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n", "\n", "We organize our runs into \"projects\" and view them on the W&B \"project page\" \n", "([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n", "\n", "By default in the Lightning integration, the project name is determined based on directory information.\n", "This default can be overridden in the code when creating a `WandbLogger`,\n", "but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
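, "\n", "\n", "As a minimal sketch, assuming a hypothetical project name `my-fsdl-experiments`, the variable can be set inline when launching training from a terminal:\n", "\n", "```\n", "WANDB_PROJECT=my-fsdl-experiments python training/run_experiment.py --wandb\n", "```\n", "\n", "The project name here is only illustrative; any name that is valid in your W&B account should work the same way."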
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what the project page looks like for a longer-running project with lots of experiments.\n", "\n", "The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n", "\n", "print(project_url)\n", "IFrame(src=project_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n", "\n", "We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Artifacts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond logging metrics and metadata from runs,\n", "we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below pulls up all of the artifacts associated with the experiment we just ran."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Click on one of the `model` checkpoints -- the specific version doesn't matter.\n", "\n", "There are a number of tabs here.\n", "\n", "The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n", "\n", "The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n", "which are added by default by the `WandbLogger`.\n", "\n", "The \"Files\" tab contains the actual file contents of the artifact.\n", "\n", "On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n", "including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n", "\n", "You can click on these to explore the different versions and even directly compare them.\n", "\n", "If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Artifact storage is part of the W&B free tier.\n", "\n", "The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n", "\n", "The former is sufficient to store ~700 model checkpoints for the Text Recognizer." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can track your data storage and compare it to your limits at this URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n", "\n", "print(storage_tracker_url)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Programmatic Access" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also programmatically access our data and metadata via the `wandb` API\n", "([docs](https://docs.wandb.ai/guides/track/public-api-guide)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wb_api = wandb.Api()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run = wb_api.run(\"/\".join( # fetch a run given\n", " [last_expt.entity, # the user or org it was logged to\n", " last_expt.project, # the \"project\", usually one of several per repo/application\n", " last_expt.id] # and a unique ID\n", "))\n", "\n", "hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n", "\n", "hist.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hist.groupby(\"epoch\")[\"train/loss\"].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this includes the artifacts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# which artifacts were created and logged?\n", "artifacts = run.logged_artifacts()\n", "\n", "for artifact in artifacts:\n", " print(f\"artifact of type {artifact.type}: {artifact.name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thanks to our 
`ImageToTextTableLogger`,\n", "we can easily recreate training or validation data that came out of our `DataLoader`s,\n", "which is normally ephemeral:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n", "artifact_dir = Path(artifact.download(root=\"training/logs\"))\n", "image_dir = artifact_dir / \"media\" / \"images\"\n", "\n", "images = [path for path in image_dir.iterdir()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "from IPython.display import Image\n", "\n", "Image(str(random.choice(images)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Advanced W&B API Usage: MLOps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the strengths of a well-instrumented experiment tracking system is that it allows\n", "automatic relation of information:\n", "what were the inputs when this model's gradient spiked?\n", "Which models have been trained on this dataset,\n", "and what was their performance?\n", "\n", "Having access and automation around this information is necessary for \"MLOps\",\n", "which applies contemporary DevOps principles to ML projects." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cells below pull down the training data\n", "for the model currently running the FSDL Text Recognizer app.\n", "\n", "This is just intended as a demonstration of what's possible,\n", "so don't worry about understanding every piece of this,\n", "and feel free to skip past it.\n", "\n", "MLOps is still a nascent field, and these tools and workflows are likely to change.\n", "\n", "For example, just before the course launched, W&B released a\n", "[Model Registry layer](https://docs.wandb.ai/guides/models)\n", "on top of artifact logging that aims to improve the developer experience for these workflows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start from the same project we looked at in the project view:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n", "\n", "text_recognizer_project " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we search it for the text recognizer model currently being used in production:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# collect all versions of the text-recognizer ever put into production by...\n", "\n", "for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n", " if art_type.name == \"prod-ready\": # for the prod-ready type\n", " # and grabbing the text-recognizer\n", " production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n", "\n", "# and then get the one that's currently being tested in CI by...\n", "for text_recognizer in production_text_recognizers:\n", " if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n", " in_prod_text_recognizer = text_recognizer\n", "\n", "# view its metadata at the url or 
in the notebook\n", "in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n", "\n", "print(in_prod_text_recognizer_url)\n", "IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From its metadata, we can get information about how it was \"staged\" to be put into production,\n", "and in particular which model checkpoint was used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "staging_run = in_prod_text_recognizer.logged_by()\n", "\n", "training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n", "training_ckpt.name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That checkpoint was logged by a training experiment, which is available as metadata.\n", "\n", "We can look at the training run for that model, either here in the notebook or at its URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "training_run = training_ckpt.logged_by()\n", "print(training_run.url)\n", "IFrame(src=training_run.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And from there, we can access logs and metadata about training,\n", "confident that we are working with the model that is actually in production.\n", "\n", "For example, we can pull down the data we logged and analyze it locally." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_results = training_run.history(samples=10000)\n", "training_results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n", "training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 10\n", "training_results[\"validation/loss\"].dropna().iloc[idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The charts and webpages in Weights & Biases\n", "are substantially more useful than ephemeral stdouts or raw logs on disk.\n", "\n", "If you're spun up on the project,\n", "they accelerate debugging, exploration, and discovery.\n", "\n", "If not, they're not so much useful as they are overwhelming.\n", "\n", "We need to synthesize the raw logged data into information.\n", "This helps us communicate our work with other stakeholders,\n", "preserve knowledge and prevent repetition of work,\n", "and surface insights faster.\n", "\n", "These workflows are supported by the W&B Reports feature\n", "([docs here](https://docs.wandb.ai/guides/reports)),\n", "which mixes W&B charts and tables with explanatory markdown text and embeds.\n", "\n", "Below are some common report patterns,\n", "with use cases and examples of each." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the examples are from the FSDL Text Recognizer project.\n", "You can find more of them\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n", "where we've organized them into a report!"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dashboard Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dashboards are a structured subset of the output from one or more experiments,\n", "designed for quickly surfacing issues or insights,\n", "like an accuracy or performance regression\n", "or a change in the data distribution.\n", "\n", "Use cases:\n", "- show the basic state of an ongoing experiment\n", "- compare one experiment to another\n", "- select the most important charts so you can spin back up into context on a project more quickly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n", "\n", "IFrame(src=dashboard_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull Request Documentation Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In most software codebases,\n", "pull requests are a key focal point\n", "for units of work that combine\n", "short-term communication and long-term information tracking.\n", "\n", "In ML codebases, it's more difficult to bring\n", "sufficient information together to make PRs as useful.\n", "At FSDL, we like to add documentary\n", "reports with one or a small number of charts\n", "that connect logged information in the experiment management system\n", "to state in the version control software.\n", "\n", "Use cases:\n", "- communication of results within a team, e.g. 
code review\n", "- record-keeping that links pull request pages to raw logged info and makes it discoverable\n", "- improving confidence in PR correctness" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n", "\n", "IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Blog Post Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With sufficient effort, the logged data in the experiment management system\n", "can be made clear enough to be consumed,\n", "sufficiently contextualized to be useful outside the team, and\n", "even beautiful.\n", "\n", "The result is a report that's closer to a blog post than a dashboard or internal document.\n", "\n", "Use cases:\n", "- communication between teams or vertically in large organizations\n", "- external technical communication for branding and recruiting\n", "- attracting users or contributors\n", "\n", "Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma)\n", "and others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n", "\n", "IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many of our choices, like the depth of our network, the nonlinearities of our layers,\n", "and the learning rate and other parameters of our optimizer, cannot be\n", "([easily](https://arxiv.org/abs/1606.04474))\n", "chosen by descent of 
the gradient of a loss function.\n", "\n", "But these parameters that impact the values of the parameters\n", "we directly optimize with gradients, or _hyperparameters_,\n", "can still be optimized,\n", "essentially by trying options and selecting the values that worked best.\n", "\n", "In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n", "\n", "Expending more compute can squeeze small amounts of additional validation or test performance\n", "that makes for impressive results on leaderboards but typically doesn't translate\n", "into better user experience.\n", "\n", "In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n", "built into your other tooling.\n", "\n", "Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n", "([docs](https://docs.wandb.ai/guides/sweeps)).\n", "\n", "It also supports a number of more advanced tools, like\n", "[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n", "for early termination of poorly-performing runs.\n", "\n", "We can use the same training script and we don't need to run an optimization server.\n", "\n", "We just need to write a configuration yaml file\n", "([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n", "like the one below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile training/simple-overfit-sweep.yaml\n", "# first we specify what we're sweeping\n", "# we specify a program to run\n", "program: training/run_experiment.py\n", "# we optionally specify how to run it, including setting default arguments\n", "command: \n", " - ${env}\n", " - ${interpreter}\n", " - ${program}\n", " - \"--wandb\"\n", " - \"--overfit_batches\"\n", " - \"1\"\n", " - \"--log_every_n_steps\"\n", " - \"25\"\n", " - \"--max_epochs\"\n", " - \"100\"\n", " - \"--limit_test_batches\"\n", " - \"0\"\n", " - ${args} # these arguments come from the sweep parameters below\n", "\n", "# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n", "method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n", "metric:\n", " name: train/loss\n", " goal: minimize\n", "parameters: \n", " # LineCNN hyperparameters\n", " window_width:\n", " values: [8, 16, 32, 64]\n", " window_stride:\n", " values: [4, 8, 16, 32]\n", " # Transformer hyperparameters\n", " tf_layers:\n", " values: [1, 2, 4, 8]\n", " # we can also fix some values, just like we set default arguments\n", " gpus:\n", " value: 1\n", " model_class:\n", " value: LineCNNTransformer\n", " data_class:\n", " value: IAMLines\n", " loss:\n", " value: transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the config we launch a \"controller\":\n", "a lightweight process that just decides what hyperparameters to try next\n", "and coordinates the heavierweight training.\n", "\n", "This lives on the W&B servers, so there are no headaches about opening ports for communication,\n", "cleaning up when it's done, etc." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n", "simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we can launch an \"agent\" to follow the orders of the controller:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "\n", "# interrupt twice to terminate this cell if it's running too long,\n", "# it can be over 15 minutes with some hyperparameters\n", "\n", "!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n", "\n", "If not provided, the agent will run forever for random or Bayesian sweeps\n", "or until the sweep is terminated, which can be done from the W&B interface." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agents make for a slick workflow for distributing sweeps across GPUs.\n", "\n", "We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n", "which controls which GPUs are accessible by a process, to launch\n", "parallel agents on separate GPUs on the same machine." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n", "# open another terminal\n", "CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n", "# and so on\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include optional exercises with the labs for learners who want to dive deeper on specific topics." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟Contribute to a hyperparameter search." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n", "\n", "There are ~10,000,000 potential hyperparameter combinations,\n", "and each takes 30 minutes to test,\n", "so checking each possibility will take over 500 years of compute time.\n", "Best get cracking then!\n", "\n", "Run the cell below to pull up a dashboard and print the URL where you can check on the current status." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_entity = \"fullstackdeeplearning\"\n", "sweep_project = \"fsdl-line-recognizer-2022\"\n", "sweep_id = \"e0eo43eu\"\n", "sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n", "\n", "print(sweep_url)\n", "IFrame(src=sweep_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve information about the sweep from the API,\n", "including the hyperparameters being swept over." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparams = sweep_info.config[\"parameters\"]\n", "hyperparams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you'd like to contribute to this sweep,\n", "run the cell below after changing the count to a number greater than 0.\n", "\n", "Each iteration runs for 30 minutes if it does not crash,\n", "e.g. due to out-of-memory errors." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "count = 0 # off by default, increase it to join in!\n", "\n", "if count:\n", " !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Write some manual logging in `wandb`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the FSDL Text Recognizer codebase,\n", "we almost exclusively log to W&B through Lightning,\n", "rather than through the `wandb` Python SDK.\n", "\n", "If you're interested in learning how to use W&B directly, e.g. with another training framework,\n", "try out this quick exercise that introduces the key players in the SDK." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n", "\n", "It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n", "\n", "Add W&B metric and artifact logging to this cell:\n", "- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n", "- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import random\n", "\n", "import wandb\n", "\n", "\n", "os.makedirs(\"logs\", exist_ok=True)\n", "\n", "project = \"trying-wandb\"\n", "config = {\"steps\": 50}\n", "\n", "\n", "with wandb.init(project=project, config=config) as run:\n", " steps = wandb.config[\"steps\"]\n", " \n", " for ii in range(steps):\n", " loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n", " \n", " with open(\"logs/hello.txt\", \"w\") as f:\n", " f.write(\"hello from wandb, 
my dudes!\")\n", " \n", " run_id = run.id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hello_run = wb_api.run(f\"{project}/{run_id}\")\n", "\n", "# check for logged loss data\n", "if \"loss\" not in hello_run.history().keys():\n", " print(\"loss not logged 🥲\")\n", "else:\n", " print(\"loss logged successfully 🥞\")\n", " if len(hello_run.history()[\"loss\"]) != steps:\n", " print(\"loss not logged on all steps 🥲\")\n", " else:\n", " print(\"loss logged on all steps 🥞\")\n", "\n", "artifacts = hello_run.logged_artifacts()\n", "\n", "# check for artifact with the right name\n", "if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n", " print(\"hello artifact not logged 🥲\")\n", "else:\n", " print(\"hello artifact logged successfully 🥞\")\n", " # check for the file inside the artifacts\n", " if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n", " print(\"could not find hello.txt 🥲\")\n", " else:\n", " print(\"hello.txt logged successfully 🥞\")\n", " \n", " \n", "hello_run" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you observe interesting phenomena during training,\n", "from promising hyperparameter combos to software bugs to strange model behavior,\n", "turn the charts into a W&B report and share it with the FSDL community or\n", "[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n", "with a link to them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# check the sweep_info.config above to see the model and data hyperparameters\n", "# read through the --help output for all potential arguments\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n", " --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n", " --help # remove this line to run an experiment instead of printing help\n", " \n", "last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟🌟🌟 Add logging of tensor statistics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to logging model inputs and outputs as human-interpretable media,\n", "it's also frequently useful to see information about their numerical values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're interested in learning more about metric calculation and logging with Lightning,\n", "use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n", "to add tensor statistic logging to the `LineCNNTransformer`.\n", "\n", "`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n", "\n", "All three are useful, but start by adding just one." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n", "- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n", " - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n", "- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n", " - Base your code on the calculation and logging of the `val_cer` metric.\n", " - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ", "collapsed_sections": [], "name": "lab04_experiments.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab06/notebooks/lab05_troubleshooting.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 05: Troubleshooting & Testing" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`\n", "- How to implement tests for ML training systems in particular\n", "- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 5\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sThWeTtV6fL_" }, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "xFP8lU4nSg1P" }, "source": [ "# Linting Python and Shell Scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "cXbdYfFlPhZ-" }, "source": [ "### Automatically linting with `pre-commit`" ] }, { "cell_type": "markdown", "metadata": { "id": "ysqqb2GjvLrz" }, "source": [ "We 
want to keep our code clean and uniform across developers\n", "and time.\n", "\n", "Applying the cleanliness checks and style rules should be\n", "as painless and automatic as possible.\n", "\n", "For this purpose, we recommend bundling linting tools together\n", "and enforcing them on all commits with\n", "[`pre-commit`](https://pre-commit.com/)." ] }, { "cell_type": "markdown", "metadata": { "id": "XvqtZChKvLr0" }, "source": [ "In addition to running on every commit,\n", "`pre-commit` separates the model development environment from the environments\n", "needed for the linting tools, preventing conflicts\n", "and simplifying maintenance and onboarding." ] }, { "cell_type": "markdown", "metadata": { "id": "Y0XuIuKOXhJl" }, "source": [ "This cell runs `pre-commit`.\n", "\n", "The first time it is run on a machine, it will install the environments for all tools." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hltYGbpNvLr1" }, "outputs": [], "source": [ "!pre-commit run --all-files" ] }, { "cell_type": "markdown", "metadata": { "id": "gLw08gIkvLr1" }, "source": [ "The output lists all the checks that are run and whether each one passed.\n", "\n", "Notice there are a number of simple version-control hygiene practices included\n", "that aren't even specific to Python, much less to machine learning.\n", "\n", "For example, several of the checks prevent accidental commits with private keys, large files, \n", "leftover debugger statements, or merge conflict annotations in them."
] }, { "cell_type": "markdown", "metadata": { "id": "RHEEjb9kvLr1" }, "source": [ "These linting actions are configured via\n", "([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n", "a YAML file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dgXa8BzrvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml" ] }, { "cell_type": "markdown", "metadata": { "id": "8HYc_WbTvLr2" }, "source": [ "Most of the general cleanliness checks are from hooks built by `pre-commit`.\n", "\n", "See the comments and links in the `.pre-commit-config.yaml` for more:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K9rTgRqzvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep repos -A 15" ] }, { "cell_type": "markdown", "metadata": { "id": "1ptkO7aPvLr2" }, "source": [ "Let's take a look at the section of the file\n", "that applies most of our Python style enforcement with\n", "[`flake8`](https://flake8.pycqa.org/en/latest/):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ALsRKfcevLr3", "scrolled": true }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10" ] }, { "cell_type": "markdown", "metadata": { "id": "a_Q0BwQUXbg6" }, "source": [ "The majority of the style checking behavior we want comes from the\n", "`additional_dependencies`, which are\n", "[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n", "that extend `flake8`'s list of lints.\n", "\n", "Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n", "\n", "We keep the configuration information for `flake8`\n", "separate from that for `pre-commit`\n", "in case we want to use additional tools with `flake8`,\n", "e.g. 
if some developers want to integrate it directly into their editor,\n", "and so that if we change away from `pre-commit`\n", "but keep `flake8` we don't have to\n", "recreate our configuration in a different tool.\n", "\n", "As much as possible, codebases should strive for single sources of truth\n", "and link back to those sources of truth with documentation or comments,\n", "as in the last line above.\n", "\n", "Let's take a look at the contents of `.flake8`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "doC_4WQwvLr3" }, "outputs": [], "source": [ "!cat .flake8" ] }, { "cell_type": "markdown", "metadata": { "id": "0Nq6HnyU0M47" }, "source": [ "There's a lot here! We'll focus on the most important bits." ] }, { "cell_type": "markdown", "metadata": { "id": "U4PiB8CPvLr3" }, "source": [ "Linting tools in Python generally work by emitting error codes\n", "with one or more letters followed by three numbers.\n", "The `select` argument picks which error codes we want to check for.\n", "Error codes are matched by prefix,\n", "so for example `B` matches `BTS101` and\n", "`G1` matches `G102` and `G199` but not `ARG404`.\n", "\n", "Certain codes are `ignore`d in the default `flake8` style,\n", "which is done via the `ignore` argument,\n", "and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n", "For example, we rely on `black` to do our formatting,\n", "so we ignore some of `flake8`'s formatting codes.\n", "\n", "Together, these settings define our project's particular style.\n", "\n", "But not every file fits this style perfectly.\n", "Most of the conventions in `black` and `flake8` come from the style-defining\n", "[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n", "which exhorts you to \"know when to be inconsistent\".\n", "\n", "To allow ourselves to be inconsistent when we know we should be,\n", "`flake8` includes `per-file-ignores`,\n", "which let us ignore specific warnings in specific files.\n", 
"This is one of the \"escape valves\"\n", "that makes style enforcement tolerable.\n", "We can also `exclude` files in the `pre-commit` config itself.\n", "\n", "For details on selecting and ignoring,\n", "see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html)\n", "\n", "For definitions of the error codes from `flake8` itself,\n", "see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n", "Individual extensions list their added error codes in their documentation,\n", "e.g. `darglint` does so\n", "[here](https://github.com/terrencepreilly/darglint#error-codes)." ] }, { "cell_type": "markdown", "metadata": { "id": "NL0TpyPsvLr4" }, "source": [ "The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n", "\n", "You can read more about each in their documentation:\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "markdown", "metadata": { "id": "mFsZC0a7vLr4" }, "source": [ "### Linting via a script and using `shellcheck`" ] }, { "cell_type": "markdown", "metadata": { "id": "RYjpuFwjXkJc" }, "source": [ "To avoid needing to think about `pre-commit`\n", "(was the command `pre-commit run` or `pre-commit check`?)\n", "while developing locally,\n", "we might put our linters into a shell script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mXlLFWmavLr4" }, "outputs": [], "source": [ "!cat tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "PPxHpRIB3nbw" }, "source": [ "These kinds of short and simple shell scripts are common in projects\n", "of intermediate 
size.\n", "\n", "They are useful for adding automation and reducing friction." ] }, { "cell_type": "markdown", "metadata": { "id": "TMuPBpAi2qwl" }, "source": [ "But these scripts are code,\n", "and all code is susceptible to bugs and subject to concerns of style consistency." ] }, { "cell_type": "markdown", "metadata": { "id": "SQRg3ZqXvLr4" }, "source": [ "We can't check these scripts with tools that lint Python code,\n", "so we include a shell script linting tool,\n", "[`shellcheck`](https://www.shellcheck.net/),\n", "in our `pre-commit`.\n", "\n", "More so than checking for correct style,\n", "this tool checks for common bugs or surprising behaviors of shells,\n", "which are unfortunately numerous." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zkfhE1srvLr4" }, "outputs": [], "source": [ "script_filename = \"tasks/lint.sh\"\n", "!pre-commit run shellcheck --files {script_filename}" ] }, { "cell_type": "markdown", "metadata": { "id": "KXU9TRrwvLr4" }, "source": [ "That script has already been tested, so we don't see any errors.\n", "\n", "Try copying over a script you've written yourself or\n", "even from a popular repo that you like\n", "(by adding to the notebook directory or by making a cell\n", "with `%%writefile` at the top)\n", "and test it by changing the `script_filename`.\n", "\n", "You'd be surprised at the classes of subtle bugs possible in bash!" 
] }, { "cell_type": "markdown", "metadata": { "id": "81MhAL-TvLr5" }, "source": [ "### Try \"unofficial bash strict mode\" for louder failures in scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "hSwhs_zUvLr5" }, "source": [ "Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n", "[@redsymbol](https://twitter.com/redsymbol),\n", "which appear at the top of the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o-j0vSxEvLr5" }, "outputs": [], "source": [ "!head -n 3 tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "d2iJU5jlvLr5" }, "source": [ "The core idea of strict mode is to fail more loudly.\n", "This is a desirable behavior of scripts,\n", "like the ones we're writing,\n", "even though it's an undesirable behavior for an interactive shell --\n", "it would be unpleasant to be logged out every time you hit an error.\n", "\n", "`set -u` means scripts fail if a variable's value is `u`nset,\n", "i.e. not defined.\n", "Otherwise bash is perfectly happy to allow you to reference undefined variables.\n", "The result is just an empty string, which can lead to maddeningly weird behavior.\n", "\n", "`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n", "rather than using the exit code of the last command.\n", "Unix tools are perfectly happy to work on nonsense input,\n", "like sorting error messages, instead of the filenames you meant to send.\n", "\n", "You can read more about these choices\n", "[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n", "and considerations for working with other non-conforming scripts in \"strict mode\"\n", "and for handling resource teardown when scripts error out." 
] }, { "cell_type": "markdown", "metadata": { "id": "s1XqsrU_XWWS" }, "source": [ "# Testing ML Codebases" ] }, { "cell_type": "markdown", "metadata": { "id": "CPNzeq3NYF2W" }, "source": [ "## Testing Python code with `pytests`" ] }, { "cell_type": "markdown", "metadata": { "id": "zq5e_x6gc9Vu" }, "source": [ "\n", "ML codebases are Python first and foremost, so first let's get some Python tests going." ] }, { "cell_type": "markdown", "metadata": { "id": "0DC3GxYz6_R9" }, "source": [ "At a basic level,\n", "we can write functions that `assert`\n", "that our code behaves as expected in\n", "a given scenario and include them in the same module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Rvd-GNwv63W1" }, "outputs": [], "source": [ "from text_recognizer.lit_models.metrics import test_character_error_rate\n", "\n", "test_character_error_rate??" ] }, { "cell_type": "markdown", "metadata": { "id": "iVB2TsQS5BTq" }, "source": [ "The standard tool for testing Python code is\n", "[`pytest`](https://docs.pytest.org/en/7.1.x/).\n", "\n", "We can use it as a command-line tool in a variety of ways,\n", "including to execute these kinds of tests.\n", "\n", "If passed a filename, `pytest` will look for\n", "any classes that start with `Test` or\n", "any functions that start with `test_` and run them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u8sQguyJvLr6", "scrolled": false }, "outputs": [], "source": [ "!pytest text_recognizer/lit_models/metrics.py" ] }, { "cell_type": "markdown", "metadata": { "id": "92tkBCllvLr6" }, "source": [ "After the results of the tests (pass or fail) are returned,\n", "you'll see a report of \"coverage\" from\n", "[`codecov`](https://about.codecov.io/).\n", "\n", "This coverage report tells us which files and how many lines in those files\n", "were touched by the testing suite."
] }, { "cell_type": "markdown", "metadata": { "id": "PllSUe0s5xvU" }, "source": [ "We do not actually need to provide the names of files with tests in them to `pytest`\n", "in order for it to run our tests." ] }, { "cell_type": "markdown", "metadata": { "id": "4qOBHJnTZM9x" }, "source": [ "By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n", "\n", "It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n", "to separate these from the rest of your code\n", "in a folder or folders named `tests`,\n", "rather than scattering them around the repo." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "acjsYTNSvLr6" }, "outputs": [], "source": [ "!ls text_recognizer/tests" ] }, { "cell_type": "markdown", "metadata": { "id": "WZQQZUF0vLr6" }, "source": [ "Let's take a look at a specific example:\n", "the tests for some of our utilities around\n", "custom PyTorch Lightning `Callback`s." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oS0xKv1evLr6" }, "outputs": [], "source": [ "from text_recognizer.tests import test_callback_utils\n", "\n", "\n", "test_callback_utils.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "lko8msn-vLr7" }, "source": [ "Notice that we can easily import this as a module!\n", "\n", "That's another benefit of organizing tests into specialized files." ] }, { "cell_type": "markdown", "metadata": { "id": "5A85FUNv75Fr" }, "source": [ "The particular utility we're testing\n", "here is designed to prevent crashes:\n", "it checks for a particular type of error and turns it into a warning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jl4-DiVe76sw" }, "outputs": [], "source": [ "from text_recognizer.callbacks.util import check_and_warn\n", "\n", "check_and_warn??" 
] }, { "cell_type": "markdown", "metadata": { "id": "B6E0MhduvLr7" }, "source": [ "Error-handling code is a common cause of bugs,\n", "a fact discovered\n", "[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n", "so it's very important to test it well!\n", "\n", "We start with a very basic test,\n", "which does not touch anything\n", "outside of the Python standard library,\n", "even though this tool is intended to be used\n", "with more complex features of third-party libraries,\n", "like `wandb` and `tensorboard`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xx5koQmJvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_simple??" ] }, { "cell_type": "markdown", "metadata": { "id": "MZe9-JVjvLr7" }, "source": [ "Here, we are just testing the core logic.\n", "This test won't catch many bugs,\n", "but when it does fail, something has gone seriously wrong.\n", "\n", "These kinds of tests are important for resolving a bug:\n", "we learn nearly as much from the tests that passed\n", "as we did from the tests that failed.\n", "If this test has failed, possibly along with others,\n", "we can rule out an issue in one of the large external codebases\n", "touched in the other tests, saving us lots of time in our troubleshooting.\n", "\n", "The reasoning for the test is explained in the docstrings, \n", "which are close to the code.\n", "\n", "Your test suite should be as welcoming\n", "as the rest of your codebase!\n", "The people reading it, for example yourself in six months, \n", "are likely upset and in need of some kindness.\n", "\n", "More practically, we want to keep our time to resolve errors as short as possible,\n", "and five minutes to write a good docstring now\n", "can save five minutes during an outage, when minutes really matter."
] }, { "cell_type": "markdown", "metadata": { "id": "Om9k-uXhvLr7" }, "source": [ "That basic test is a start, but it's not enough by itself.\n", "There's a specific error case that triggered the addition of this code.\n", "\n", "So we test that it's handled as expected." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fjbsb5FvvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_tblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "CGAIZTUjvLr7" }, "source": [ "That test can fail if the libraries change around our code,\n", "i.e. if the `TensorBoardLogger` gets a `log_table` method.\n", "\n", "We want to be careful when making assumptions\n", "about other people's software,\n", "especially for fast-moving libraries like Lightning.\n", "If we test that those assumptions hold willy-nilly,\n", "we'll end up with tests that fail because of\n", "harmless changes in our dependencies.\n", "\n", "Tests that require a ton of maintenance and updating\n", "without leading to code improvements soak up\n", "more engineering time than they save\n", "and cause distrust in the testing suite.\n", "\n", "We include this test because `TensorBoardLogger` getting\n", "a `log_table` method will _also_ change the behavior of our code\n", "in a breaking way, and we want to catch that before it breaks\n", "a model training job." ] }, { "cell_type": "markdown", "metadata": { "id": "jsy95KAvvLr7" }, "source": [ "Adding error handling can also accidentally kill the \"happy path\"\n", "by raising an error incorrectly.\n", "\n", "So we explicitly test the _absence of an error_,\n", "not just its presence:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LRlIOkjmvLr8" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_wandblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "osiqpLynvLr8" }, "source": [ "There are more tests we could build, e.g. 
manipulating classes and testing the behavior,\n", "testing more classes that might be targeted by `check_and_warn`, or\n", "asserting that warnings are raised to the command line.\n", "\n", "But these three basic tests are likely to catch most changes that would break our code here,\n", "and they're a lot easier to write than the others.\n", "\n", "If this utility starts to get more usage and become a critical path for lots of features, we can always add more!" ] }, { "cell_type": "markdown", "metadata": { "id": "dm285JE5vLr8" }, "source": [ "## Interleaving testing and documentation with `doctests`" ] }, { "cell_type": "markdown", "metadata": { "id": "UHWQvgA8vLr8" }, "source": [ "One function of tests is to build user/reader confidence in code." ] }, { "cell_type": "markdown", "metadata": { "id": "wrhiJBXFvLr8" }, "source": [ "One function of documentation is to build user/reader knowledge in code." ] }, { "cell_type": "markdown", "metadata": { "id": "1vu12LDhvLr8" }, "source": [ "These functions are related. Let's put them together:\n", "put code in a docstring and test that code.\n", "\n", "This feature is part of the\n", "Python standard library via the\n", "[`doctest` module](https://docs.python.org/3/library/doctest.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "rmfIOwXd-Qt7" }, "source": [ "Here's an example from our `torch` utilities.\n", "\n", "The `first_appearance` function can be used to\n", "e.g. quickly look for stop tokens,\n", "giving the length of each sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZzURGcD9vLr8" }, "outputs": [], "source": [ "from text_recognizer.lit_models.util import first_appearance\n", "\n", "\n", "first_appearance??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0VtYcJ1WvLr8" }, "source": [ "Notice that in the \"Examples\" section,\n", "there's a short block of code formatted as a\n", "Python interpreter session,\n", "complete with outputs.\n", "\n", "We can copy and paste that code and\n", "check that we get the right outputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dj4lNOxJvLr9" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y9AWHFoIvLr9" }, "source": [ "We can run the test with `pytest` by passing a command line argument,\n", "`--doctest-modules`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JMaAxv5ovLr9" }, "outputs": [], "source": [ "!pytest --doctest-modules text_recognizer/lit_models/util.py" ] }, { "cell_type": "markdown", "metadata": { "id": "6-2_aOUfvLr9" }, "source": [ "With the\n", "[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n", "running `doctest`s happens automatically\n", "when `pytest` is invoked." ] }, { "cell_type": "markdown", "metadata": { "id": "my_keokPvLr9" }, "source": [ "## Basic tests for data code" ] }, { "cell_type": "markdown", "metadata": { "id": "Qj3Bq_j2_A8o" }, "source": [ "ML code can be hard to test\n", "since it involves very heavy artifacts, like models and data,\n", "and very expensive jobs, like training."
] }, { "cell_type": "markdown", "metadata": { "id": "DT5OmgrQvLr9" }, "source": [ "For testing our data-handling code in the FSDL codebase,\n", "we mostly just use `assert`s,\n", "which throw errors when behavior differs from expectation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bdzn5g4TvLr9" }, "outputs": [], "source": [ "!grep \"assert\" -r text_recognizer/data" ] }, { "cell_type": "markdown", "metadata": { "id": "2aTlfu4_vLr-" }, "source": [ "This isn't great practice,\n", "especially as a codebase grows,\n", "because we can't easily know when these are executed\n", "or incorporate them into\n", "testing automation and coverage analysis tools." ] }, { "cell_type": "markdown", "metadata": { "id": "IaMTdmbZ_mkW" }, "source": [ "So it's preferable to collect up these assertions of simple data properties\n", "into tests that are run like our other tests.\n", "\n", "The test below checks whether any data is leaking\n", "between training, validation, and testing." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qx7cxiDdvLr-" }, "outputs": [], "source": [ "from text_recognizer.tests.test_iam import test_iam_data_splits\n", "\n", "\n", "test_iam_data_splits??" 
] }, { "cell_type": "markdown", "metadata": { "id": "16TJwhd1vLr-" }, "source": [ "Notice that we were able to load the test into the notebook\n", "because it is in a module,\n", "and so we can run it here as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mArITFkYvLr-" }, "outputs": [], "source": [ "test_iam_data_splits()" ] }, { "cell_type": "markdown", "metadata": { "id": "E4F2uaclvLr-" }, "source": [ "But we're checking something pretty simple here,\n", "so the new code in each test is just a single line.\n", "\n", "What if we wanted to test more complex properties,\n", "like comparing rows or calculating statistics?\n", "\n", "We'll end up writing more complex code that might itself have subtle bugs,\n", "requiring tests for our tests and suffering from\n", "\"tester's regress\".\n", "\n", "This is the phenomenon,\n", "named by analogy with\n", "[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n", "in sociology of science,\n", "where the validity of our tests is itself\n", "up for dispute only resolvable by testing the tests,\n", "but those tests are themselves possibly invalid." ] }, { "cell_type": "markdown", "metadata": { "id": "nUGT06gdvLr-" }, "source": [ "We cut this Gordian knot by using\n", "a library or framework that is well-tested.\n", "\n", "We recommend checking out\n", "[`great_expectations`](https://docs.greatexpectations.io/docs/)\n", "if you're looking for a high-quality data testing tool." ] }, { "cell_type": "markdown", "metadata": { "id": "dQ5vNsq3vLr-" }, "source": [ "Especially with data, some tests are particularly \"heavy\" --\n", "they take a long time,\n", "and we might want to run them\n", "on different machines\n", "and on a different schedule\n", "than our other tests." 
] }, { "cell_type": "markdown", "metadata": { "id": "xephcb0LvLr-" }, "source": [ "For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n", "\n", "We can't just use a cached version of the data,\n", "since that won't actually execute the code!\n", "\n", "This test will take\n", "as long to run\n", "and consume as many resources as\n", "a full download of the data." ] }, { "cell_type": "markdown", "metadata": { "id": "YSN4w2EqvLr-" }, "source": [ "`pytest` allows the separation of tests\n", "into suites with `mark`s,\n", "which \"tag\" tests with names." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V0rScrcXvLr_", "scrolled": false }, "outputs": [], "source": [ "!pytest --markers | head -n 10" ] }, { "cell_type": "markdown", "metadata": { "id": "lr5Ca7B0vLr_" }, "source": [ "We can choose to run tests with a given mark\n", "or to skip tests with a given mark, \n", "among other basic logical operations around combining and filtering marks,\n", "with `-m`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xmw-Eb1ZvLr_" }, "outputs": [], "source": [ "!wandb login # one test requires wandb authentication\n", "\n", "!pytest -m \"not data and not slow\"" ] }, { "cell_type": "markdown", "metadata": { "id": "5LuERxOXX_UJ" }, "source": [ "## Testing training with memorization tests" ] }, { "cell_type": "markdown", "metadata": { "id": "AnWLN4lRvLsA" }, "source": [ "Training is the process by which we convert inert data into executable models,\n", "so it is dependent on both.\n", "\n", "We decouple checking whether the script has a critical bug\n", "from whether the data or model code is broken\n", "by testing on some basic \"fake data\",\n", "based on a utility from `torchvision`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k4NIc3uWvLsA" }, "outputs": [], "source": [ "from text_recognizer.data import FakeImageData\n", "\n", "\n", "FakeImageData.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "deN0swwlvLsA" }, "source": [ "We then test on the actual data with a smaller version of the real model.\n", "\n", "We use the Lightning `--fast_dev_run` feature,\n", "which sets the number of training, validation, and test batches to `1`.\n", "\n", "We use a smaller version so that this test can run in just a few minutes\n", "on a CPU without acceleration.\n", "\n", "That allows us to run our tests in environments without GPUs,\n", "which saves on costs for executing tests.\n", "\n", "Here's the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z4J0_uD9vLsA" }, "outputs": [], "source": [ "!cat training/tests/test_run_experiment.sh" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-7u9zS1vLsA", "scrolled": false }, "outputs": [], "source": [ "! ./training/tests/test_run_experiment.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "UTzfo11KClV3" }, "source": [ "The above tests don't actually check\n", "whether any learning occurs,\n", "they just check\n", "whether training runs mechanically,\n", "without any errors.\n", "\n", "We also need a\n", "[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n", "for learning.\n", "For that we recommend checking whether\n", "the model can learn the right\n", "outputs for a single batch --\n", "to \"memorize\" the outputs for\n", "a particular input.\n", "\n", "This memorization test won't\n", "catch every bug or issue in training,\n", "which is notoriously difficult,\n", "but it will flag\n", "some of the most serious issues." ] }, { "cell_type": "markdown", "metadata": { "id": "0DVSp3aAvLsA" }, "source": [ "The script below runs a memorization test."
] }, { "cell_type": "markdown", "metadata": { "id": "2DFVVrxpvLsA" }, "source": [ "It takes up to two arguments:\n", "a `MAX`imum number of `EPOCHS` to run for and\n", "a `CRITERION` value of the loss to test against.\n", "\n", "The test passes if the loss is lower than the `CRITERION` value\n", "after the `MAX`imum number of `EPOCHS` has passed." ] }, { "cell_type": "markdown", "metadata": { "id": "oEhJH0e5vLsB" }, "source": [ "The important line in this script is the one that invokes our training script,\n", "`training/run_experiment.py`.\n", "\n", "The arguments to `run_experiment` have been tuned for maximum possible speed:\n", "turning off regularization, shrinking the model,\n", "and skipping parts of Lightning that we don't want to test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T-fFs1xEvLsB" }, "outputs": [], "source": [ "!cat training/tests/test_memorize_iam.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "X-47tUA_YNGe" }, "source": [ "If you'd like to see what a memorization run looks like,\n", "flip the `running_memorization` flag to `True`\n", "and watch the results stream in to W&B.\n", "\n", "The cell should run in about ten minutes on a commodity GPU." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GwTEsZwKvLsB" }, "outputs": [], "source": [ "%%time\n", "running_memorization = False\n", "\n", "if running_memorization:\n", " max_epochs = 1000\n", " loss_criterion = 0.05\n", " !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Troubleshooting model speed with the PyTorch Profiler" ] }, { "cell_type": "markdown", "metadata": { "id": "DpbN-Om2Drf-" }, "source": [ "Testing code is only half the story here:\n", "we also need to fix the issues that our tests flag.\n", "This is the process of troubleshooting.\n", "\n", "In this lab,\n", "we'll focus on troubleshooting model performance issues:\n", "what to do when your model runs too slowly." ] }, { "cell_type": "markdown", "metadata": { "id": "NZzwELPXvLsD" }, "source": [ "Troubleshooting deep neural networks for speed is challenging.\n", "\n", "There are at least three different common approaches,\n", "each with an increasing level of skill required:\n", "\n", "1. Follow best practices advice from others\n", "([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n", "[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n", "2. Take code that runs slowly and use empirical observations to iteratively improve it.\n", "3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n", "\n", "For the full stack deep learning engineer,\n", "the final level is typically out of reach,\n", "unless you're specializing in the model performance\n", "part of the stack in particular.\n", "\n", "So we recommend reaching the middle level,\n", "and this segment of the lab walks through the\n", "tools that make this easier."
] }, { "cell_type": "markdown", "metadata": { "id": "3_yp87UrFZ8M" }, "source": [ "Because neural network training involves GPU acceleration,\n", "generic Python profiling tools like\n", "[`py-spy`](https://github.com/benfred/py-spy)\n", "won't work, and\n", "we'll need tools specialized for tracing and profiling DNN training." ] }, { "cell_type": "markdown", "metadata": { "id": "yspsYVFGEyZm" }, "source": [ "In general, these tools are for observing what happens while your code is executing:\n", "_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n", "\n", "Because they help us observe the execution in detail,\n", "they will also help us understand just what is going on during\n", "a PyTorch training step in greater detail." ] }, { "cell_type": "markdown", "metadata": { "id": "YqXq2hKuvLsE" }, "source": [ "To support profiling and tracing,\n", "we've added a new argument to `training/run_experiment.py`, `--profile`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z_GMMViWvLsE" }, "outputs": [], "source": [ "!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\"" ] }, { "cell_type": "markdown", "metadata": { "id": "ZldoksHPvLsE" }, "source": [ "As with experiment management, this relies mostly on features of PyTorch Lightning,\n", "which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n", "and we just add a few lines of customization:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F2iJ0_A6vLsE" }, "outputs": [], "source": [ "!cat training/run_experiment.py | grep args.profile -A 5" ] }, { "cell_type": "markdown", "metadata": { "id": "Aw3ppgndvLsE" }, "source": [ "For more on profiling with Lightning, see the\n", "[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "uCAmNW3QEtcD" }, "source": [ "The cell below runs an epoch of training with tracing and profiling turned on\n", "and then saves the results locally and to W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t4o3ylDgr46F", "scrolled": false }, "outputs": [], "source": [ "import glob\n", "\n", "import torch\n", "import wandb\n", "\n", "from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n", "\n", "\n", "# make it easier to separate these from training runs\n", "%env WANDB_JOB_TYPE=profile\n", "\n", "batch_size = 16\n", "num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n", "gpus = 1 # must be run with accelerator\n", "\n", "%run training/run_experiment.py --wandb --profile \\\n", " --max_epochs=1 \\\n", " --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n", " --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n", " --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus={gpus}\n", "\n", "latest_expt = wandb.run\n", "\n", "try: # add execution trace to logged and versioned binaries\n", " folder = wandb.run.dir\n", " trace_matcher = folder + \"/*.pt.trace.json\"\n", " trace_file = glob.glob(trace_matcher)[0]\n", " trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n", " trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n", " wandb.log_artifact(trace_at)\n", "except IndexError:\n", " print(\"trace not found\")\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": { "id": "ePTkS3EqO5tN" }, "source": [ "We get out a table of statistics in the terminal,\n", "courtesy of Lightning.\n", "\n", "Each row lists an operation\n", "and provides information,\n", "described in the column headers,\n", "about the time spent on that operation\n", "across all the training steps we profiled.\n", "\n", "With 
practice, some useful information can be read out from this table,\n", "but it's better to start from both a less detailed view,\n", "in the TensorBoard dashboard,\n", "and a more detailed view,\n", "using the Chrome Trace viewer." ] }, { "cell_type": "markdown", "metadata": { "id": "TzV62f3c7-Bi" }, "source": [ "## High-level statistics from the PyTorch Profiler in TensorBoard" ] }, { "cell_type": "markdown", "metadata": { "id": "mNPKXkYw8NWd" }, "source": [ "Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CbItwuT88eAV" }, "outputs": [], "source": [ "your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n", "\n", "print(your_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "jE_LooMYHFpF" }, "source": [ "If at any point you run into issues,\n", "like the description not matching what you observe,\n", "check out one of our example runs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "za2zybSwIo5C" }, "outputs": [], "source": [ "example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n", "print(example_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "xlrhl1n4HYU6" }, "source": [ "Once the TensorBoard session has loaded up,\n", "we are dropped into the Overview\n", "(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n", "for an example).\n", "\n", "In the top center, we see the **GPU Summary** for our system.\n", "\n", "In addition to the name of our GPU,\n", "there are a few configuration details and top-level statistics.\n", "They are (tersely) documented\n", "[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)." 
] }, { "cell_type": "markdown", "metadata": { "id": "MmBhUDgDLhd1" }, "source": [ "- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n", "this is effectively a coarse \"version number\" for your GPU hardware.\n", "It indexes which features are available,\n", "with more advanced features being available only at higher compute capabilities.\n", "It does not directly index the speed or memory of the GPU." ] }, { "cell_type": "markdown", "metadata": { "id": "voUgT6zuLyi0" }, "source": [ "- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase." ] }, { "cell_type": "markdown", "metadata": { "id": "Yl-IndtXE4b4" }, "source": [ "- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n", "for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n", "Tensor Cores.\n", "If you're running on an older GPU without Tensor Cores,\n", "you should consider upgrading.\n", "If you're running a more recent GPU but not seeing Tensor Core usage,\n", "you should switch to half precision (16-bit) floating point numbers,\n", "which Tensor Cores are specialized for." ] }, { "cell_type": "markdown", "metadata": { "id": "XxcUf0bBNXy_" }, "source": [ "- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n", "at a lower level than just whether something is running at all,\n", "as in utilization.\n", "Unlike utilization, reaching 100% is not generally feasible\n", "and sometimes not desirable.\n", "Increasing these numbers requires expertise in\n", "CUDA programming, so we'll target utilization instead."
] }, { "cell_type": "markdown", "metadata": { "id": "A88pQn4YMMKc" }, "source": [ "- **Execution Summary**: This table and pie chart indicate\n", "how much time within a profiled step\n", "was spent in each category.\n", "The value for \"kernel\" execution here\n", "is equal to the GPU utilization,\n", "and we want that number to be as close to 100%\n", "as possible.\n", "This summary helps us know which\n", "other operations are taking time,\n", "like memory being copied between CPU and GPU (`memcpy`)\n", "or `DataLoader`s executing on the CPU,\n", "so we can decide where the bottleneck is." ] }, { "cell_type": "markdown", "metadata": { "id": "6qjW1RlTQRPv" }, "source": [ "At the very bottom, you'll find a\n", "**Performance Recommendation**\n", "tab that sometimes suggests specific methods for improving performance.\n", "\n", "If this tab makes suggestions, you should certainly take them!" ] }, { "cell_type": "markdown", "metadata": { "id": "pWY5AhrcRQmJ" }, "source": [ "For more on using the profiler in TensorBoard,\n", "including some of the other, more detailed views\n", "available via the \"Views\" dropdown menu, see\n", "[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
] }, { "cell_type": "markdown", "metadata": { "id": "mQwrPY_H77H8" }, "source": [ "## Going deeper with the Chrome Trace Viewer" ] }, { "cell_type": "markdown", "metadata": { "id": "yhwo7fslvLsH" }, "source": [ "So far, we've seen summary-level information about our training steps\n", "in the table from Lightning and in the TensorBoard Overview.\n", "These give aggregate statistics about the computations that occurred,\n", "but understanding how to interpret those statistics\n", "and use them to speed up our networks\n", "requires understanding just what is\n", "happening in our training step.\n", "\n", "Fundamentally,\n", "all computations are processes that unfold in time.\n", "\n", "If we want to really understand our training step,\n", "we need to display it that way:\n", "what operations were occurring,\n", "on both the CPU and GPU,\n", "at each moment in time during the training step.\n", "\n", "This information on timing is collected in the trace.\n", "One of the best tools for viewing the trace over time\n", "is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)." ] }, { "cell_type": "markdown", "metadata": { "id": "wUkZItxYc20A" }, "source": [ "Let's tour the trace we just logged\n", "with an aim to really understanding just\n", "what is happening when we call\n", "`training_step`\n", "and by extension `.forward`, `.backward`, and `optimizer.step`." 
] }, { "cell_type": "markdown", "metadata": { "id": "9w9F2UA7Qctg" }, "source": [ "The Chrome Trace Viewer is built into W&B,\n", "so we can view our traces in their interface.\n", "\n", "The cell below embeds the trace inside the notebook,\n", "but you may wish to open it separately,\n", "with the \"Open page\" button or by navigating to the URL,\n", "so that you can interact with it\n", "as you read the description below.\n", "Display directly on W&B is also a bit less temperamental\n", "than display on W&B inside a notebook.\n", "\n", "Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n", "so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n", "It also can interact poorly with browser extensions (e.g. ad blockers),\n", "so you may need to deactivate them temporarily in order to see it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OMUs4aby6Rfd" }, "outputs": [], "source": [ "trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n", "trace_url = trace_files_url + \"training_step.pt.trace.json\"\n", "\n", "example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n", "\n", "print(trace_url)\n", "IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qNVpGeQtQjMG" }, "source": [ "> **Heads up!** We're about to do a tour of the\n", "> precise details of the tracing information logged\n", "> during the execution of the training code.\n", "> The only way to learn how to troubleshoot model performance\n", "> empirically is to look at the details,\n", "> but the details depend on the precise machine being used\n", "> -- GPU and CPU and RAM.\n", "> That means even within Colab,\n", "> these details change from session to session.\n", "> So if you don't 
observe a phenomenon or feature\n", "> described in the tour below, check out\n", "> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n", "> on W&B while reading through the next section of the lab,\n", "> and return to your trace once you understand the trace viewer better at the end.\n", "> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n", "> can sometimes be a bit janky." ] }, { "cell_type": "markdown", "metadata": { "id": "kXMcBhnCgdN_" }, "source": [ "This trace reveals, in nanosecond-level detail,\n", "what's going on inside of a `training_step`\n", "on both the GPU and the CPU.\n", "\n", "Time is on the horizontal axis.\n", "Colored bars represent method calls,\n", "and the methods called by a method are placed underneath it vertically,\n", "a visualization known as an\n", "[icicle chart](https://www.brendangregg.com/flamegraphs.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "67BsNzDfVIeg" }, "source": [ "Let's orient ourselves with some gross features:\n", "the forwards pass,\n", "GPU kernel execution,\n", "the backwards pass,\n", "and the optimizer step." 
] }, { "cell_type": "markdown", "metadata": { "id": "IBEFgtRCKqrh" }, "source": [ "### The forwards pass" ] }, { "cell_type": "markdown", "metadata": { "id": "5nYhiWesVMjK" }, "source": [ "Type in `resnet` to the search bar in the top-right.\n", "\n", "This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n", "\n", "It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n", "\n", "You can click the arrows next to that tile to partially collapse these blocks.\n", "\n", "Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n", "It should be at roughly the same height.\n", "\n", "Clear the search bar so that the trace is in color.\n", "Zoom in on the area of the forwards pass\n", "using the \"zoom\" tool in the floating toolbar,\n", "so you can see more detail.\n", "The zoom tool is indicated by a two-headed arrow\n", "pointing into and out of the screen.\n", "\n", "Switch to the \"drag\" tool,\n", "represented by a four-headed arrow.\n", "Click-and-hold to use this tool to focus\n", "on different parts of the timeline\n", "and click on the individual colored boxes\n", "to see details about a particular method call.\n", "\n", "As we go down in the icicle chart,\n", "we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n", "to much more precise `cudnn` and `cuda` operations\n", "(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n", "\n", "`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n", "is the tensor math library in PyTorch\n", "that links to specific backends like `cudnn`." 
] }, { "cell_type": "markdown", "metadata": { "id": "Fq181ybIvLsH" }, "source": [ "### GPU kernel execution" ] }, { "cell_type": "markdown", "metadata": { "id": "IbkWp5aKvLsH" }, "source": [ "Towards the bottom, you should see a section labeled \"GPU\".\n", "The label appears on the far left.\n", "\n", "Within it, you'll see one or more \"`stream`s\".\n", "These are units of work on a GPU,\n", "akin loosely to threads on the CPU.\n", "\n", "When there are colored bars in this area,\n", "the GPU is doing work of some kind.\n", "The fraction of this bar that is filled in with color\n", "is the same as the \"GPU Utilization %\" we've seen previously.\n", "So the first thing to visually assess\n", "in a trace view of PyTorch code\n", "is what fraction of this area is filled with color.\n", "\n", "In CUDA, work is queued up to be\n", "placed into streams and completed, on the GPU,\n", "in a distributed and asynchronous manner.\n", "\n", "The selection of which work to do\n", "is happening on the CPU,\n", "and that's what we were looking at above.\n", "\n", "The CPU and the GPU have to work together to coordinate\n", "this work.\n", "\n", "Type `cuda` into the search bar and you'll see these coordination operations happening:\n", "`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n", "\n", "Running the same PyTorch model\n", "with the same high level operations like `Conv2d` in different versions of PyTorch,\n", "on different GPUs, and even on tensors of different sizes will result\n", "in different choices of concrete kernel operation,\n", "e.g. different matrix multiplication algorithms.\n", "\n", "Type `sync` into the search bar and you'll see places where either work on the GPU\n", "or work on the CPU needs to await synchronization,\n", "e.g. 
copying data from the CPU to the GPU\n", "or the CPU waiting to decide what to do next\n", "on the basis of the contents of a tensor.\n", "\n", "If you see a \"sync\" block above an area\n", "where the stream on the GPU is empty,\n", "you've got a performance bottleneck due to synchronization\n", "between the CPU and GPU.\n", "\n", "To resolve the bottleneck,\n", "head up the icicle chart until you reach the recognizable\n", "PyTorch modules and operations.\n", "Find where they are called in your PyTorch module.\n", "That's a good place to review your code to understand why the synchronization is happening\n", "and remove it if it's not necessary." ] }, { "cell_type": "markdown", "metadata": { "id": "XeMPbu_jvLsI" }, "source": [ "### The backwards pass\n", "\n", "Type `backward` into the search bar.\n", "\n", "This will highlight components of our backwards pass.\n", "\n", "If you read it from left to right,\n", "you'll see that it begins by calculating the loss\n", "(type `NllLoss2DBackward` into the search bar if you can't find it)\n", "and ends by doing a `ConvolutionBackward`,\n", "the first layer of the ResNet.\n", "It is, indeed, backwards.\n", "\n", "Like the forwards pass,\n", "the backwards pass also involves the CPU\n", "telling the GPU which kernels to run.\n", "It's typically run in a separate\n", "thread from the forwards pass,\n", "so you'll see it separated out from the forwards pass\n", "in the trace viewer.\n", "\n", "Generally, there's no need to specifically optimize the backwards pass --\n", "removing bottlenecks in the forwards pass results in a fast backwards pass.\n", "\n", "One reason why is that these two passes are just\n", "\"transposes\" of one another,\n", "so they share a lot of properties,\n", "and bottlenecks in one become bottlenecks in the other.\n", "We can choose to optimize either one of the two.\n", "But the forwards pass is under our direct control,\n", "so it's easier for us to reason about.\n", "\n", "Another reason is that 
the forwards pass is more likely to have bottlenecks.\n", "The forwards pass is a dynamic process,\n", "with each line of Python adding more to the compute graph.\n", "Backwards passes, on the other hand, use a static compute graph,\n", "the one just defined by the forwards pass,\n", "so more optimizations are possible." ] }, { "cell_type": "markdown", "metadata": { "id": "gWiDw0vCvLsI" }, "source": [ "### The optimizer step" ] }, { "cell_type": "markdown", "metadata": { "id": "ndfkzEdnvLsI" }, "source": [ "Type in `Adam.step` to the search bar to highlight the computations of the optimizer.\n", "\n", "As with the two passes,\n", "we are still using the CPU\n", "to launch kernels on the GPU.\n", "But now the CPU is looping,\n", "in Python, over the parameters\n", "and applying the Adam update rules to each.\n", "\n", "We now know enough to see that\n", "this is not great for our GPU utilization:\n", "there are many areas of gray\n", "in between the colored bars\n", "in the GPU stream in this area.\n", "\n", "In the time it takes CUDA to multiply\n", "thousands of numbers,\n", "Python has not yet finished cleaning up\n", "after its request for that multiplication.\n", "\n", "As of writing in August 2022,\n", "more efficient optimizers are not a stable part of PyTorch (v1.12), but\n", "[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n", "and stable implementations outside of PyTorch.\n", "The standard implementations are\n", "[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n", "not to be confused with the\n", "[Apex Optimizers Project](https://www.apexoptimizers.com/),\n", "which is a collection of fitness-themed cheetah NFTs."
] }, { "cell_type": "markdown", "metadata": { "id": "WX0jxeafvLsI" }, "source": [ "## Take-aways for PyTorch performance bottleneck troubleshooting" ] }, { "cell_type": "markdown", "metadata": { "id": "CugD-bK2vLsI" }, "source": [ "Our goal here was to learn some basic principles and tools for finding bottlenecks,\n", "targeting the most common issues and the lowest-hanging fruit in PyTorch code." ] }, { "cell_type": "markdown", "metadata": { "id": "SwHwJkVMHYGA" }, "source": [ "\n", "Here's an overview in terms of a \"host\",\n", "generally the CPU,\n", "and a \"device\", here the GPU.\n", "\n", "- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n", "- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n", " - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n", "- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n", "stores information about the contents of tensors,\n", "not just their metadata." ] }, { "cell_type": "markdown", "metadata": { "id": "Gntx28p9cBP5" }, "source": [ "Towards that goal, we viewed the trace to get an understanding of\n", "what's going on inside a PyTorch training step."
] }, { "cell_type": "markdown", "metadata": { "id": "AKvZGPnkeXvq" }, "source": [ "Here's what that means in terms of troubleshooting bottlenecks.\n", "\n", "We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n", "before the previous kernel finishes.\n", "\n", "Ideally, the CPU is actually getting far _ahead_ of execution\n", "on the GPU.\n", "If the CPU makes it all the way through the backwards pass before the GPU is done,\n", "that's great!\n", "The GPU(s) are the expensive part,\n", "and it's easy to use multiprocessing so that\n", "the CPU has other things to do.\n", "\n", "This helps explain at least one common piece of advice:\n", "the larger our batches are,\n", "the more work the GPU has to do for the same work done by the CPU,\n", "and so the better our utilization will be." ] }, { "cell_type": "markdown", "metadata": { "id": "XMztpa-TccH4" }, "source": [ "We operationalize our desire to never be waiting on the CPU with a simple metric:\n", "**100% GPU utilization**, meaning a kernel is running at all times.\n", "\n", "This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n", "\n", "You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will more than double your costs." ] }, { "cell_type": "markdown", "metadata": { "id": "7kYBygfScR6z" }, "source": [ "Here are some of the most common issues that lead to low GPU Utilization, and how to resolve them:\n", "1. **The CPU is too weak**.\n", "Because so much of the discussion around DNN performance is about GPUs,\n", "it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n", "_Resolution_:\n", "Use nice CPUs, like\n", "[threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n", "2. 
**Too much Python during the `training_step`**.\n", "Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on an `__init__`\n", "that takes longer than running an entire layer.\n", "_Resolution_:\n", "Look for low utilization areas of the trace,\n", "check what's happening on the CPU at that time,\n", "and carefully review the Python code being executed.\n", "3. **Unnecessary Host/Device synchronization**.\n", "If one of your operations depends on the values in a tensor,\n", "like `if xs.mean() >= 0`,\n", "you'll induce a synchronization between\n", "the host and the device and possibly lead\n", "to an expensive and slow copy of data.\n", "_Resolution_:\n", "Replace these operations as much as possible\n", "with purely array-based calculations.\n", "4. **Bottlenecking on the DataLoader**.\n", "In addition to coordinating the work on the GPU,\n", "CPUs often perform heavy data operations,\n", "including communication over the network\n", "and writing to/reading from disk.\n", "These are generally done in parallel to the forwards\n", "and backwards passes,\n", "but if the next batch isn't ready by the time those passes finish,\n", "data loading will become the bottleneck.\n", "_Resolution_:\n", "Get better hardware for compute,\n", "memory, and network.\n", "For software solutions, the answer\n", "is a bit more complex and application-dependent.\n", "For generic tips, see\n", "[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n", "in the PyTorch forums.\n", "For techniques in computer vision, see\n", "[the FFCV library](https://github.com/libffcv/ffcv)\n", "and for techniques in NLP, see e.g.\n", "[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n", "and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)."
] }, { "cell_type": "markdown", "metadata": { "id": "i2WYS8bQvLsJ" }, "source": [ "### Further steps in making DNNs go brrrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "T0wW2_lRKfY1" }, "source": [ "It's important to note that utilization\n", "is just an easily measured metric\n", "that can reveal common bottlenecks.\n", "Having high utilization does not automatically mean\n", "that your performance is fully optimized.\n", "\n", "For example,\n", "synchronization events between GPUs\n", "are counted as kernels,\n", "so a deadlock during distributed training\n", "can show up as 100% utilization,\n", "despite literally no useful work occurring.\n", "\n", "Just switching to \n", "double precision floats, `--precision=64`,\n", "will generally lead to much higher utilization.\n", "The GPU operations take longer\n", "for roughly the same amount of CPU effort,\n", "but the added precision brings no benefit.\n", "\n", "In particular, it doesn't make for models\n", "that perform better on our correctness metrics,\n", "like loss and accuracy.\n", "\n", "Another useful yardstick to add\n", "to utilization is examples per second,\n", "which incorporates how quickly the model is processing data examples\n", "and calculating gradients.\n", "\n", "But really,\n", "the gold star is _decrease in loss per second_.\n", "This metric connects model design choices\n", "and hyperparameters with purely engineering concerns,\n", "so it disrespects abstraction barriers\n", "and doesn't generally lead to actionable recommendations,\n", "but it is, in the end, the real goal:\n", "make the loss go down faster so we get better models sooner." ] }, { "cell_type": "markdown", "metadata": { "id": "EFzPsplfdo_o" }, "source": [ "For PyTorch internals abstractly,\n", "see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "For more on performance considerations in PyTorch,\n", "see [Horace He's blog post](https://horace.io/brrr_intro.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "yq6-S6TC38AY" }, "source": [ "### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n", "\n", "One of the most important features for making\n", "PyTorch run quickly is the\n", "multiprocessing `DataLoader`,\n", "which executes batching of data in separate worker processes\n", "from the forwards and backwards passes.\n", "\n", "By default in PyTorch,\n", "this feature is actually turned off,\n", "via the `DataLoader` argument `num_workers`\n", "having a default value of `0`,\n", "but we set the `DEFAULT_NUM_WORKERS`\n", "to a value based on the number of CPUs\n", "available on the system running the code.\n", "\n", "Re-run the profiling cell,\n", "but set `num_workers` to `0`\n", "to turn off multiprocessing.\n", "\n", "Compare and contrast the two traces,\n", "both for total runtime\n", "(see the time axis at the top of the trace)\n", "and for utilization.\n", "\n", "If you're unable to run the profiles,\n", "see the results\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n", "which juxtaposes two traces,\n", "with in-process dataloading on the left and\n", "multiprocessing dataloading on the right." ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test."
] }, { "cell_type": "markdown", "metadata": { "id": "T2i_a5eVeIoA" }, "source": [ "The file below incorrectly implements and then incorrectly tests\n", "a simple PyTorch utility for adding five to every entry of a tensor\n", "and then calculating the sum.\n", "\n", "Even worse, it does it with horrible style!\n", "\n", "The cells below apply our linting checks\n", "(after automatically fixing the formatting)\n", "and run the test.\n", "\n", "Fix all of the lints,\n", "implement the function correctly,\n", "and then implement some basic tests." ] }, { "cell_type": "markdown", "metadata": { "id": "wSon2fB5VVM_" }, "source": [ "- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aYiRvU4HA84t" }, "outputs": [], "source": [ "%%writefile training/fixme.py\n", "import torch\n", "from training import run_experiment\n", "from numpy import *\n", "import random\n", "from pathlib import Path\n", "\n", "\n", "\n", "\n", "def add_five_and_sum(tensor):\n", " # this function is not implemented right,\n", " # but it's supposed to add five to all tensor entries and sum them up\n", " return 1\n", "\n", "def test_add_five_and_sum():\n", " # and this test isn't right either! 
plus this isn't exactly a docstring\n", " all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n", " all_fives = 5 * all_ones\n", " assert False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EXJpmvuzT1w0" }, "outputs": [], "source": [ "!pre-commit run black --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SRO-oJfdUrcQ" }, "outputs": [], "source": [ "!cat training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jM8NHxVbSEQD" }, "outputs": [], "source": [ "!pre-commit run --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kj0VMBSndtkc" }, "outputs": [], "source": [ "!pytest training/fixme.py" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab05_troubleshooting.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab06/notebooks/lab06_data.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 06: Data Annotation" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How the `IAM` handwriting dataset is structured on disk and how it is processed into an ML-friendly format\n", "- How to set up a [Label Studio](https://labelstud.io/) data annotation server\n", 
"- Just how messy data really is" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 6\n", "\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "DpvaHz9TEGwV" }, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gsXpeXi2EGwV" }, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "IFrame(src=\"https://fsdl.me/2022-lab-06-video-embed\", width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "XTkKzEMNR8XZ" }, "source": [ "# `IAMParagraphs`: From annotated data to a PyTorch `Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "3mQLbjuiwZuj" }, "source": [ "We've used the `text_recognizer.data` submodule\n", "and its `LightningDataModule`s -- `IAMLines` and `IAMParagraphs`\n", "for lines and paragraphs of handwritten 
text\n", "from the\n", "[IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database).\n", "\n", "These classes convert data from a database-friendly format\n", "designed for storage and transfer into the\n", "format our DNNs expect:\n", "PyTorch `Tensor`s.\n", "\n", "In this section,\n", "we'll walk through that process in detail.\n", "\n", "In the following section,\n", "we'll see how data\n", "goes from signals measured in the world\n", "to the format we consume here." ] }, { "cell_type": "markdown", "metadata": { "id": "499c23a6" }, "source": [ "## Dataset structure on disk" ] }, { "cell_type": "markdown", "metadata": { "id": "a3438d2e" }, "source": [ "We begin by downloading the raw data to disk." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "18900eec" }, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM\n", "\n", "iam = IAM()\n", "iam.prepare_data()" ] }, { "cell_type": "markdown", "metadata": { "id": "a332f359" }, "source": [ "The `IAM` dataset is downloaded as a zip file\n", "and then unzipped:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d6c44266" }, "outputs": [], "source": [ "from text_recognizer.metadata.iam import DL_DATA_DIRNAME\n", "\n", "\n", "iam_dir = DL_DATA_DIRNAME\n", "!ls {iam_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "8463c2d1" }, "source": [ "The unzipped dataset is not simply a flat directory of files.\n", "\n", "Instead, there are a number of subfolders,\n", "each of which contains a particular type of data or metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "536924f7" }, "outputs": [], "source": [ "iamdb = iam_dir / \"iamdb\"\n", "\n", "!du -h {iamdb}" ] }, { "cell_type": "markdown", "metadata": { "id": "b745a594" }, "source": [ "For example, the `task` folder contains metadata about canonical dataset splits:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "84c21f75" }, "outputs": [], "source": [ "!find {iamdb / \"task\"} | grep \"\\\\.txt$\"" ] }, { "cell_type": "markdown", "metadata": { "id": "mEb0Pdm4vIHe" }, "source": [ "We find the images of handwritten text in the `forms` folder.\n", "\n", "An individual \"datapoint\" in `IAM` is a \"form\",\n", "because the humans whose hands wrote the text were prompted to write on \"forms\",\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "945d5e3a" }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "\n", "form_fn, = !find {iamdb}/forms | grep \".jpg$\" | sort | head -n 1\n", "\n", "print(form_fn)\n", "Image(filename=form_fn, width=\"360\")" ] }, { "cell_type": "markdown", "metadata": { "id": "b9e9e384" }, "source": [ "Meanwhile, the `xml` files contain the data annotations,\n", "written out as structured text:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6add5c5a" }, "outputs": [], "source": [ "xml_fn, = !find {iamdb}/xml | grep \"\\.xml$\" | sort | head -n 1\n", "\n", "!cat {xml_fn} | grep -A 100 \"handwritten-part\" | grep \" str: return f"{KEY}/" + super()._add_prefix(*args, **kwargs) ================================================ FILE: lab06/text_recognizer/callbacks/util.py ================================================ import logging logging.basicConfig(level=logging.WARNING) def check_and_warn(logger, attribute, feature): if not hasattr(logger, attribute): warn_no_attribute(feature, attribute) return True def warn_no_attribute(blocked_feature, missing_attribute): 
logging.warning(f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}.") ================================================ FILE: lab06/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. """ from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines from .iam_paragraphs import IAMParagraphs from .iam_lines import IAMLines from .fake_images import FakeImageData from .iam_synthetic_paragraphs import IAMSyntheticParagraphs from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs ================================================ FILE: lab06/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): 
return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed data parallel mode, we launch a training process on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS class BaseDataModule(pl.LightningDataModule): """Base for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. 
Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`). """ def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting. Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab06/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. 
Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. 
""" def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import 
loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, 
num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab06/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab06/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims = 
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines) ================================================ FILE: lab06/text_recognizer/data/fake_images.py ================================================ """A fake image dataset for testing.""" import argparse import torch import torchvision from text_recognizer.data.base_data_module import BaseDataModule _NUM_SAMPLES = 512 _IMAGE_LEN = 28 _NUM_CLASSES = 10 class FakeImageData(BaseDataModule): """Fake images dataset.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.num_samples = self.args.get("num_samples", _NUM_SAMPLES) self.input_dims = (1, self.args.get("image_height", _IMAGE_LEN), self.args.get("image_width", _IMAGE_LEN)) self.num_classes = self.args.get("num_classes", _NUM_CLASSES) self.output_dims = (self.num_classes, 1) self.mapping = list(range(0, self.num_classes)) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--num_samples", type=int, default=_NUM_SAMPLES) parser.add_argument("--num_classes", type=int, default=_NUM_CLASSES) parser.add_argument("--image_height", type=int, default=_IMAGE_LEN) parser.add_argument("--image_width", type=int, default=_IMAGE_LEN) return parser def setup(self, stage: str = None) -> None: fake_dataset = torchvision.datasets.FakeData( size=self.num_samples, image_size=self.input_dims, num_classes=self.output_dims[0], transform=torchvision.transforms.ToTensor(), ) val_size = int(self.num_samples * 0.25) self.data_train, self.data_val, self.data_test = torch.utils.data.random_split( # type: ignore dataset=fake_dataset, lengths=[self.num_samples - 2 * val_size, val_size, val_size] ) ================================================ FILE: lab06/text_recognizer/data/iam.py 
================================================ """Class for loading the IAM handwritten text dataset, which encompasses both paragraphs and lines, plus utilities.""" from pathlib import Path from typing import Any, cast, Dict, List, Optional import zipfile from boltons.cacheutils import cachedproperty from defusedxml import ElementTree from PIL import Image, ImageOps import toml from text_recognizer import util from text_recognizer.data.base_data_module import _download_raw_dataset, load_and_print_info import text_recognizer.metadata.iam as metadata from text_recognizer.metadata.iam_paragraphs import NEW_LINE_TOKEN METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME EXTRACTED_DATASET_DIRNAME = metadata.EXTRACTED_DATASET_DIRNAME class IAM: """A dataset of images of handwritten text written on a form underneath a typewritten prompt. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database Images are identified by their "form ID". These IDs are used to separate train, validation and test splits, as keys for dictonaries returning label and image crop region data, and more. The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (LWITLRT): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. 
""" def __init__(self): self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self): if self.xml_filenames: return filename = _download_raw_dataset(self.metadata, DL_DATA_DIRNAME) # type: ignore _extract_raw_dataset(filename, DL_DATA_DIRNAME) def load_image(self, id: str) -> Image.Image: """Load and return an image of an entire IAM form. The image is grayscale with white text on black background. This image will have the printed prompt text at the top, above the handwritten text. Images of individual words or lines and of whole paragraphs can be cropped out using the relevant crop region data. """ image = util.read_image_pil(self.form_filenames_by_id[id], grayscale=True) image = ImageOps.invert(image) return image def __repr__(self): """Print info about the dataset.""" info = ["IAM Dataset"] info.append(f"Total Images: {len(self.xml_filenames)}") info.append(f"Total Test Images: {len(self.test_ids)}") info.append(f"Total Paragraphs: {len(self.paragraph_string_by_id)}") num_lines = sum(len(line_regions) for line_regions in self.line_regions_by_id.values()) info.append(f"Total Lines: {num_lines}") return "\n\t".join(info) @cachedproperty def all_ids(self): """A list of all form IDs.""" return sorted([f.stem for f in self.xml_filenames]) @cachedproperty def ids_by_split(self): return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids} @cachedproperty def split_by_id(self): """A dictionary mapping form IDs to their split according to IAM Lines LWITLRT.""" split_by_id = {id_: "train" for id_ in self.train_ids} split_by_id.update({id_: "val" for id_ in self.validation_ids}) split_by_id.update({id_: "test" for id_ in self.test_ids}) return split_by_id @cachedproperty def train_ids(self): """A list of form IDs which are in the IAM Lines LWITLRT training set.""" return list(set(self.all_ids) - (set(self.test_ids) | set(self.validation_ids))) @cachedproperty def test_ids(self): """A list of form IDs from the IAM Lines LWITLRT test set."""
return _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/testset.txt") @property def xml_filenames(self) -> List[Path]: """A list of the filenames of all .xml files, which contain label information.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @cachedproperty def validation_ids(self): """A list of form IDs from IAM Lines LWITLRT validation sets 1 and 2.""" val_ids = _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset1.txt") val_ids.extend(_get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset2.txt")) return val_ids @property def form_filenames(self) -> List[Path]: """A list of the filenames of all .jpg files, which contain images of IAM forms.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def xml_filenames_by_id(self): """A dictionary mapping form IDs to their XML label information files.""" return {filename.stem: filename for filename in self.xml_filenames} @property def form_filenames_by_id(self): """A dictionary mapping form IDs to their JPEG images.""" return {filename.stem: filename for filename in self.form_filenames} @cachedproperty def line_strings_by_id(self): """A dict mapping an IAM form id to its list of line texts.""" return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def line_regions_by_id(self): """A dict mapping an IAM form id to its list of line image crop regions.""" return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def paragraph_string_by_id(self): """A dict mapping an IAM form id to its paragraph text.""" return {id: NEW_LINE_TOKEN.join(line_strings) for id, line_strings in self.line_strings_by_id.items()} @cachedproperty def paragraph_region_by_id(self): """A dict mapping an IAM form id to its paragraph image crop region.""" return { id: { "x1": min(region["x1"] for region in line_regions), 
"y1": min(region["y1"] for region in line_regions), "x2": max(region["x2"] for region in line_regions), "y2": max(region["y2"] for region in line_regions), } for id, line_regions in self.line_regions_by_id.items() } def _extract_raw_dataset(filename: Path, dirname: Path) -> None: print("Extracting IAM data") with util.temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zip_file: zip_file.extractall() def _get_ids_from_lwitlrt_split_file(filename: str) -> List[str]: """Get the ids from Large Writer Independent Text Line Recognition Task (LWITLRT) data split file.""" with open(filename, "r") as f: line_ids_str = f.read() line_ids = line_ids_str.split("\n") page_ids = list({"-".join(line_id.split("-")[:2]) for line_id in line_ids if line_id}) return page_ids def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace &quot; with ".""" xml_line_elements = _get_line_elements_from_xml_file(filename) return [_get_text_from_xml_element(el) for el in xml_line_elements] def _get_text_from_xml_element(xml_element: Any) -> str: """Extract text from any XML element.""" return xml_element.attrib["text"].replace("&quot;", '"') def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get the line region dict for each line.""" xml_line_elements = _get_line_elements_from_xml_file(filename) line_regions = [ cast(Dict[str, int], _get_region_from_xml_element(xml_elem=el, xml_path="word/cmp")) for el in xml_line_elements ] assert all(region is not None for region in line_regions), "Line regions cannot be None" # next_line_region["y1"] - prev_line_region["y2"] can be negative due to overlapping characters line_gaps_y = [ max(next_line_region["y1"] - prev_line_region["y2"], 0) for next_line_region, prev_line_region in zip(line_regions[1:], line_regions[:-1]) ] post_line_gaps_y = line_gaps_y + [2 * metadata.LINE_REGION_PADDING] pre_line_gaps_y = [2 * metadata.LINE_REGION_PADDING] +
line_gaps_y return [ { "x1": region["x1"] - metadata.LINE_REGION_PADDING, "x2": region["x2"] + metadata.LINE_REGION_PADDING, "y1": region["y1"] - min(metadata.LINE_REGION_PADDING, pre_line_gaps_y[i] // 2), "y2": region["y2"] + min(metadata.LINE_REGION_PADDING, post_line_gaps_y[i] // 2), } for i, region in enumerate(line_regions) ] def _get_line_elements_from_xml_file(filename: str) -> List[Any]: """Get all line xml elements from xml file.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec return xml_root_element.findall("handwritten-part/line") def _get_region_from_xml_element(xml_elem: Any, xml_path: str) -> Optional[Dict[str, int]]: """ Get region from input xml element. The region is downsampled because the stored images are also downsampled. Parameters ---------- xml_elem xml element can be a line or word element with x, y, width, and height attributes xml_path should be "word/cmp" if xml_elem is a line element, else "cmp" """ unit_elements = xml_elem.findall(xml_path) if not unit_elements: return None return { "x1": min(int(el.attrib["x"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y1": min(int(el.attrib["y"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "x2": max(int(el.attrib["x"]) + int(el.attrib["width"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y2": max(int(el.attrib["y"]) + int(el.attrib["height"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, } if __name__ == "__main__": load_and_print_info(IAM) ================================================ FILE: lab06/text_recognizer/data/iam_lines.py ================================================ """A dataset of lines of handwritten text derived from the IAM dataset.""" import argparse import json from pathlib import Path from typing import Sequence import numpy as np from PIL import Image, ImageFile from text_recognizer import util from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam 
import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.line import IAMLineStem ImageFile.LOAD_TRUNCATED_IMAGES = True PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR class IAMLines(BaseDataModule): """Lines of text pulled from the IAM Handwriting database.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true") == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = IAMLineStem() self.trainval_transform = IAMLineStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if PROCESSED_DATA_DIRNAME.exists(): return print("Cropping IAM line regions...") iam = IAM() iam.prepare_data() crops_train, labels_train = generate_line_crops_and_labels(iam, "train") crops_val, labels_val = generate_line_crops_and_labels(iam, "val") crops_test, labels_test = generate_line_crops_and_labels(iam, "test") shapes = np.array([crop.size for crop in crops_train + crops_val + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] print("Saving images, labels, and statistics...") save_images_and_labels(crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_val, labels_val, "val", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt", "w") as file: file.write(str(aspect_ratios.max())) def 
setup(self, stage: str = None) -> None: with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt") as file: max_aspect_ratio = float(file.read()) image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) assert image_width <= metadata.IMAGE_WIDTH if stage == "fit" or stage is None: x_train, labels_train = load_processed_crops_and_labels("train", PROCESSED_DATA_DIRNAME) y_train = convert_strings_to_labels(labels_train, self.inverse_mapping, length=self.output_dims[0]) self.data_train = BaseDataset(x_train, y_train, transform=self.trainval_transform) x_val, labels_val = load_processed_crops_and_labels("val", PROCESSED_DATA_DIRNAME) y_val = convert_strings_to_labels(labels_val, self.inverse_mapping, length=self.output_dims[0]) self.data_val = BaseDataset(x_val, y_val, transform=self.trainval_transform) # quick check: do we have the right sequence lengths? assert self.output_dims[0] >= max([len(_) for _ in labels_train]) + 2 # Add 2 for start/end tokens. assert self.output_dims[0] >= max([len(_) for _ in labels_val]) + 2 # Add 2 for start/end tokens. 
if stage == "test" or stage is None: x_test, labels_test = load_processed_crops_and_labels("test", PROCESSED_DATA_DIRNAME) y_test = convert_strings_to_labels(labels_test, self.inverse_mapping, length=self.output_dims[0]) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) assert self.output_dims[0] >= max([len(_) for _ in labels_test]) + 2 def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Lines Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def generate_line_crops_and_labels(iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR): """Create both cropped lines and associated labels from IAM, with resizing by default""" crops, labels = [], [] for iam_id in iam.ids_by_split[split]: labels += iam.line_strings_by_id[iam_id] image = iam.load_image(iam_id) for line in iam.line_regions_by_id[iam_id]: coords = [line[point] for point in ["x1", "y1", "x2", "y2"]] crop = image.crop(coords) crop = resize_image(crop, scale_factor=scale_factor) crops.append(crop) assert len(crops) == len(labels) return crops, labels def save_images_and_labels(crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path): (data_dirname / split).mkdir(parents=True, exist_ok=True) with open(data_dirname / split / "_labels.json", "w") as f: json.dump(labels, f) 
for ind, crop in enumerate(crops): crop.save(data_dirname / split / f"{ind}.png") def load_processed_crops_and_labels(split: str, data_dirname: Path): """Load line crops and labels for given split from processed directory.""" crops = load_processed_line_crops(split, data_dirname) labels = load_processed_line_labels(split, data_dirname) assert len(crops) == len(labels) return crops, labels def load_processed_line_crops(split: str, data_dirname: Path): """Load line crops for given split from processed directory.""" crop_filenames = sorted((data_dirname / split).glob("*.png"), key=lambda filename: int(Path(filename).stem)) crops = [util.read_image_pil(filename, grayscale=True) for filename in crop_filenames] return crops def load_processed_line_labels(split: str, data_dirname: Path): """Load line labels for given split from processed directory.""" with open(data_dirname / split / "_labels.json") as file: labels = json.load(file) return labels if __name__ == "__main__": load_and_print_info(IAMLines) ================================================ FILE: lab06/text_recognizer/data/iam_original_and_synthetic_paragraphs.py ================================================ """IAM Original and Synthetic Paragraphs Dataset class.""" import argparse from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs class IAMOriginalAndSyntheticParagraphs(BaseDataModule): """A concatenation of original and synthetic IAM paragraph datasets.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.iam_paragraphs = IAMParagraphs(args) self.iam_syn_paragraphs = IAMSyntheticParagraphs(args) self.input_dims = self.iam_paragraphs.input_dims self.output_dims = self.iam_paragraphs.output_dims self.mapping = self.iam_paragraphs.mapping self.inverse_mapping = 
{v: k for k, v in enumerate(self.mapping)} @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") IAMSyntheticParagraphs.add_to_argparse(parser) return parser def prepare_data(self, *args, **kwargs) -> None: self.iam_paragraphs.prepare_data() self.iam_syn_paragraphs.prepare_data() def setup(self, stage: str = None) -> None: self.iam_paragraphs.setup(stage) self.iam_syn_paragraphs.setup(stage) if stage == "fit" or stage is None: self.data_train = ConcatDataset([self.iam_paragraphs.data_train, self.iam_syn_paragraphs.data_train]) self.data_val = self.iam_paragraphs.data_val if stage == "test" or stage is None: self.data_test = self.iam_paragraphs.data_test def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data if __name__ == "__main__": load_and_print_info(IAMOriginalAndSyntheticParagraphs) ================================================ FILE: lab06/text_recognizer/data/iam_paragraphs.py ================================================ """IAM Paragraphs Dataset class.""" import argparse import json from pathlib import Path from typing import Callable, Dict, Optional, Sequence, Tuple import numpy as np 
from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.paragraph import ParagraphStem IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME class IAMParagraphs(BaseDataModule): """IAM Handwriting database paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true").lower() == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = ParagraphStem() self.trainval_transform = ParagraphStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): return rank_zero_info( "IAMParagraphs.prepare_data: Cropping IAM paragraph regions and saving them along with labels..." 
) iam = IAM() iam.prepare_data() properties = {} for split in ["train", "val", "test"]: crops, labels = get_paragraph_crops_and_labels(iam=iam, split=split) save_crops_and_labels(crops=crops, labels=labels, split=split) properties.update( { id_: { "crop_shape": crops[id_].size[::-1], "label_length": len(label), "num_lines": _num_lines(label), } for id_, label in labels.items() } ) with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: def _load_dataset(split: str, transform: Callable) -> BaseDataset: crops, labels = load_processed_crops_and_labels(split) Y = convert_strings_to_labels(strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]) return BaseDataset(crops, Y, transform=transform) rank_zero_info(f"IAMParagraphs.setup({stage}): Loading IAM paragraph regions and lines...") validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) if stage == "fit" or stage is None: self.data_train = _load_dataset(split="train", transform=self.trainval_transform) self.data_val = _load_dataset(split="val", transform=self.transform) if stage == "test" or stage is None: self.data_test = _load_dataset(split="test", transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), 
xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def validate_input_and_output_dimensions( input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] ) -> None: """Validate input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] # Add 2 because of start and end tokens assert output_dims is not None and output_dims[0] >= properties["label_length"]["max"] + 2 def get_paragraph_crops_and_labels( iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR ) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: """Create IAM paragraph crops and labels for a given split, with resizing.""" crops = {} labels = {} for iam_id in iam.ids_by_split[split]: image = iam.load_image(iam_id) para_region = iam.paragraph_region_by_id[iam_id] crops[iam_id] = image.crop([para_region[_] for _ in ["x1", "y1", "x2", "y2"]]) crops[iam_id] = resize_image(crops[iam_id], scale_factor=scale_factor) labels[iam_id] = iam.paragraph_string_by_id[iam_id] assert len(crops) == len(labels) return crops, labels def save_crops_and_labels(crops: Dict[str, Image.Image], labels: Dict[str, str], split: str): """Save crops, labels and shapes of crops of a split.""" (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with open(_labels_filename(split), "w") as f: json.dump(labels, f, indent=4) for id_, crop in crops.items(): crop.save(_crop_filename(id_, split)) def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: """Load processed crops and labels for given split.""" with open(_labels_filename(split), "r") as f: labels = json.load(f) sorted_ids = sorted(labels.keys()) ordered_crops = [Image.open(_crop_filename(id_, split)).convert("L") for id_ in 
sorted_ids] ordered_labels = [labels[id_] for id_ in sorted_ids] assert len(ordered_crops) == len(ordered_labels) return ordered_crops, ordered_labels def get_dataset_properties() -> dict: """Return properties describing the overall dataset.""" with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: properties = json.load(f) def _get_property_values(key: str) -> list: return [_[key] for _ in properties.values()] crop_shapes = np.array(_get_property_values("crop_shape")) aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0] return { "label_length": { "min": min(_get_property_values("label_length")), "max": max(_get_property_values("label_length")), }, "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))}, "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)}, "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, } def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" def _crop_filename(id_: str, split: str) -> Path: """Return filename of processed crop.""" return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" def _num_lines(label: str) -> int: """Return number of lines of text in label.""" return label.count(NEW_LINE_TOKEN) + 1 if __name__ == "__main__": load_and_print_info(IAMParagraphs) ================================================ FILE: lab06/text_recognizer/data/iam_synthetic_paragraphs.py ================================================ """IAM Synthetic Paragraphs Dataset class.""" import argparse import random from typing import Any, Callable, List, Sequence, Tuple import numpy as np from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info import torch from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( generate_line_crops_and_labels, 
load_processed_line_crops, load_processed_line_labels, save_images_and_labels, ) from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.util import convert_strings_to_labels import text_recognizer.metadata.iam_synthetic_paragraphs as metadata NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME DATASET_LEN = metadata.DATASET_LEN class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database synthetic paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.line_crops = None self.line_labels = None self.dataset_len = self.args.get("dataset_len", DATASET_LEN) def prepare_data(self, *args, **kwargs) -> None: """ Prepare IAM lines such that they can be used to generate synthetic paragraphs dataset in setup(). This method is IAMLines.prepare_data + resizing of line crops. """ if PROCESSED_DATA_DIRNAME.exists(): return rank_zero_info( "IAMSyntheticParagraphs.prepare_data: preparing IAM lines for synthetic IAM paragraph creation..." 
) iam = IAM() iam.prepare_data() for split in ["train"]: # synthetic dataset is only used in training phase rank_zero_info(f"Cropping IAM line regions and loading labels for {split} data split...") crops, labels = generate_line_crops_and_labels(iam, split) save_images_and_labels(crops, labels, split, PROCESSED_DATA_DIRNAME) def setup(self, stage: str = None) -> None: rank_zero_info(f"IAMSyntheticParagraphs.setup({stage}): Loading train IAM paragraph regions and lines...") if stage == "fit" or stage is None: self._load_processed_crops_and_labels() self.data_train = IAMSyntheticParagraphsDataset( line_crops=self.line_crops, line_labels=self.line_labels, dataset_len=self.dataset_len, inverse_mapping=self.inverse_mapping, input_dims=self.input_dims, output_dims=self.output_dims, transform=self.trainval_transform, ) def _load_processed_crops_and_labels(self): if self.line_crops is None: self.line_crops = load_processed_line_crops("train", PROCESSED_DATA_DIRNAME) if self.line_labels is None: self.line_labels = load_processed_line_labels("train", PROCESSED_DATA_DIRNAME) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Synthetic Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def add_to_argparse(parser): parser.add_argument("--dataset_len", type=int, default=DATASET_LEN) return parser class IAMSyntheticParagraphsDataset(torch.utils.data.Dataset): """Dataset of synthetic paragraphs built out of individual IAM lines.""" def __init__( self, line_crops: List[Image.Image], line_labels: List[str], dataset_len: int, inverse_mapping: dict, input_dims: 
Tuple[int, ...], output_dims: Tuple[int, ...], transform: Callable = None, ) -> None: super().__init__() self.line_crops = line_crops self.line_labels = line_labels assert len(self.line_crops) == len(self.line_labels) self.ids = list(range(len(self.line_labels))) self.dataset_len = dataset_len self.inverse_mapping = inverse_mapping self.input_dims = input_dims self.output_dims = output_dims self.transform = transform self.min_num_lines, self.max_num_lines = 1, 15 self.seed_set = False def __len__(self) -> int: """Return length of the dataset.""" return self.dataset_len def _set_seed(self, seed): if not self.seed_set: print(f"Setting seed to {seed} for worker {torch.utils.data.get_worker_info()}") random.seed(seed) self.seed_set = True def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a random paragraph, using the first index as a seed.""" # Since shuffle is True for train dataloaders, the first index will be different on different GPUs self._set_seed(index) num_lines = random.randint(self.min_num_lines, self.max_num_lines) indices = random.sample(self.ids, k=num_lines) while True: datum = join_line_crops_to_form_paragraph([self.line_crops[i] for i in indices]) labels = NEW_LINE_TOKEN.join([self.line_labels[i] for i in indices]) if ( (len(labels) <= self.output_dims[0] - 2) and (datum.height <= self.input_dims[1]) and (datum.width <= self.input_dims[2]) ): break indices = indices[:-1] if self.transform is not None: datum = self.transform(datum) length = self.output_dims[0] target = convert_strings_to_labels(strings=[labels], mapping=self.inverse_mapping, length=length)[0] return datum, target def join_line_crops_to_form_paragraph(line_crops: Sequence[Image.Image]) -> Image.Image: """Vertically stack line crops and return a single image forming the paragraph.""" crop_shapes = np.array([_.size[::-1] for _ in line_crops]) para_height = crop_shapes[:, 0].sum() para_width = crop_shapes[:, 1].max() para_image = Image.new(mode="L", size=(para_width,
para_height), color=0) current_height = 0 for line_crop in line_crops: para_image.paste(line_crop, box=(0, current_height)) current_height += line_crop.height return para_image if __name__ == "__main__": load_and_print_info(IAMSyntheticParagraphs) ================================================ FILE: lab06/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab06/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import string from typing import List, Optional import nltk import numpy as np from 
text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def __init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in string.punctuation}) text = re.sub(" +", " ", text) return text def 
load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab06/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. 
Parameters ---------- index Returns ------- (datum, target) """ datum, target = self.data[index], self.targets[index] if self.transform is not None: datum = self.transform(datum) if self.target_transform is not None: target = self.target_transform(target) return datum, target def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor: """ Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]: """ Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the other of size (1 - fraction) * size of the base_dataset. """ split_a_size = int(fraction * len(base_dataset)) split_b_size = len(base_dataset) - split_a_size return torch.utils.data.random_split( # type: ignore base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed) ) def resize_image(image: Image.Image, scale_factor: int) -> Image.Image: """Resize image by scale factor.""" if scale_factor == 1: return image return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR) ================================================ FILE: lab06/text_recognizer/lit_models/__init__.py ================================================ from .base import BaseLitModel from .transformer import TransformerLitModel ================================================ FILE: lab06/text_recognizer/lit_models/base.py ================================================ """Basic LightningModules on which other modules can be built.""" import argparse import pytorch_lightning as pl import torch from torchmetrics import Accuracy from .metrics import CharacterErrorRate OPTIMIZER = "Adam" LR = 1e-3 LOSS = "cross_entropy" ONE_CYCLE_TOTAL_STEPS = 100 class BaseLitModel(pl.LightningModule): """ Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
""" def __init__(self, model, args: argparse.Namespace = None): super().__init__() self.model = model self.args = vars(args) if args is not None else {} self.data_config = self.model.data_config self.mapping = self.data_config["mapping"] self.input_dims = self.data_config["input_dims"] optimizer = self.args.get("optimizer", OPTIMIZER) self.optimizer_class = getattr(torch.optim, optimizer) self.lr = self.args.get("lr", LR) loss = self.args.get("loss", LOSS) if loss not in ("transformer",): self.loss_fn = getattr(torch.nn.functional, loss) self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() @staticmethod def add_to_argparse(parser): parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") parser.add_argument("--lr", type=float, default=LR) parser.add_argument("--one_cycle_max_lr", type=float, default=None) parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional") return parser def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr) if self.one_cycle_max_lr is None: return optimizer scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps ) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} def forward(self, x): return self.model(x) def predict(self, x): logits = self.model(x) return torch.argmax(logits, dim=1) def training_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.train_acc(logits, y) self.log("train/loss", loss) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) outputs = {"loss": loss} 
self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def _run_on_batch(self, batch, with_preds=False): x, y = batch logits = self(x) loss = self.loss_fn(logits, y) return x, y, logits, loss def validation_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.val_acc(logits, y) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) outputs = {"loss": loss} self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.test_acc(logits, y) self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) def add_on_first_batch(self, metrics, outputs, batch_idx): if batch_idx == 0: outputs.update(metrics) def add_on_logged_batches(self, metrics, outputs): if self.is_logged_batch(): outputs.update(metrics) def is_logged_batch(self): if self.trainer is None: return False else: return self.trainer._logger_connector.should_update_logs class BaseImageToTextLitModel(BaseLitModel): # pylint: disable=too-many-ancestors """Base class for ImageToText models in PyTorch Lightning.""" def __init__(self, model, args: argparse.Namespace = None): super().__init__(model, args) self.model = model self.args = vars(args) if args is not None else {} self.inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} self.start_index = self.inverse_mapping["<S>"] self.end_index = self.inverse_mapping["<E>"] self.padding_index = self.inverse_mapping["<P>"] self.ignore_tokens = [self.start_index, self.end_index, self.padding_index] self.val_cer = CharacterErrorRate(self.ignore_tokens) self.test_cer = CharacterErrorRate(self.ignore_tokens) ================================================ FILE: lab06/text_recognizer/lit_models/metrics.py ================================================ """Special-purpose metrics for tracking our model performance.""" from typing import Sequence import torch import torchmetrics class CharacterErrorRate(torchmetrics.CharErrorRate): """Character error rate metric, allowing for tokens to be ignored.""" def __init__(self, ignore_tokens: Sequence[int], *args): super().__init__(*args) self.ignore_tokens = set(ignore_tokens) def update(self, preds: torch.Tensor, targets: torch.Tensor): # type: ignore preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()] targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()] super().update(preds_l, targets_l) def test_character_error_rate(): metric = CharacterErrorRate([0, 1]) X = torch.tensor( [ [0, 2, 2, 3, 3, 1], # error will be 0 [0, 2, 1, 1, 1, 1], # error will be .75 [0, 2, 2, 4, 4, 1], # error will be .5 ] ) Y = torch.tensor( [ [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], ] ) metric(X, Y) assert metric.compute() == sum([0, 0.75, 0.5]) / 3 if __name__ == "__main__": test_character_error_rate() ================================================ FILE: lab06/text_recognizer/lit_models/transformer.py ================================================ """An encoder-decoder Transformer model""" from typing import List, Sequence import torch from .base import BaseImageToTextLitModel from .util import replace_after class TransformerLitModel(BaseImageToTextLitModel): """ Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.
The module must implement an encode and decode method, and the forward method should be the forward pass during production inference. """ def __init__(self, model, args=None): super().__init__(model, args) self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index) def forward(self, x): return self.model(x) def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x. Parameters ---------- x Batch of images to be encoded. See self.model.encode for shape information. y Batch of ground truth output sequences. Returns ------- torch.Tensor (B, C, Sy) logits """ x = self.model.encode(x) output = self.model.decode(x, y) # (Sy, B, C) return output.permute(1, 2, 0) # (B, C, Sy) def training_step(self, batch, batch_idx): x, y = batch logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("train/loss", loss) outputs = {"loss": loss} if self.is_logged_batch(): preds = self.get_preds(logits) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs}) return outputs def validation_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.val_cer(preds, y) self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) 
loss = self.loss_fn(logits, y[:, 1:]) self.log("test/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.test_cer(preds, y) self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def map(self, ks: Sequence[int], ignore: bool = True) -> str: """Maps an iterable of integers to a string using the lit model's mapping.""" if ignore: return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens]) else: return "".join([self.mapping[k] for k in ks]) def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]: """Maps a list of lists of integers to a list of strings using the lit model's mapping.""" return [self.map(k, ignore) for k in ks] def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor: """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index. Parameters ---------- logitlikes (B, C, Sy) Tensor with classes as second dimension. The largest value is the one whose index we will return. Logits, logprobs, and probs are all acceptable. replace_after_end Whether to replace values after the first appearance of the end token with the padding token. Returns ------- torch.Tensor (B, Sy) Tensor of integers in [0, C-1] representing predictions.
""" raw = torch.argmax(logitlikes, dim=1) # (B, C, Sy) -> (B, Sy) if replace_after_end: return replace_after(raw, self.end_index, self.padding_index) # (B, Sy) else: return raw # (B, Sy) ================================================ FILE: lab06/text_recognizer/lit_models/util.py ================================================ from typing import Union import torch def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: """Return indices of first appearance of element in x, collapsing along dim. Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 Parameters ---------- x One or two-dimensional Tensor to search for element. element Item to search for inside x. dim Dimension of Tensor to collapse over. Returns ------- torch.Tensor Indices where element occurs in x. If element is not found, return length of x along dim. One dimension smaller than x. Raises ------ ValueError if x is not a 1 or 2 dimensional Tensor Examples -------- >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3) tensor([2, 1, 3, 0]) >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0) tensor(0) """ if x.dim() > 2 or x.dim() == 0: raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}") matches = x == element first_appearance_mask = (matches.cumsum(dim) == 1) & matches does_match, match_index = first_appearance_mask.max(dim) first_inds = torch.where(does_match, match_index, x.shape[dim]) return first_inds def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor: """Replace all values in each row of 2d Tensor x after the first appearance of element with replace. Parameters ---------- x Two-dimensional Tensor (shape denoted (B, S)) to replace values in. element Item to search for inside x. replace Item that replaces entries that appear after element. 
Returns ------- outs New Tensor of same shape as x with values after element replaced. Examples -------- >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4) tensor([[1, 2, 3], [2, 3, 4], [1, 1, 1], [3, 4, 4]]) """ first_appearances = first_appearance(x, element, dim=1) # (B,) indices = torch.arange(0, x.shape[-1]).type_as(x) # (S,) outs = torch.where( indices[None, :] <= first_appearances[:, None], # if index is at or before first appearance x, # return the value from x replace, # otherwise, return the replacement value ) return outs # (B, S) ================================================ FILE: lab06/text_recognizer/metadata/emnist.py ================================================ from pathlib import Path import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json" NUM_SPECIAL_TOKENS = 4 INPUT_SHAPE = (28, 28) DIMS = (1, *INPUT_SHAPE) # Extra dimension added by ToTensor() OUTPUT_DIMS = (1,) MAPPING = [ "<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] ================================================ FILE: lab06/text_recognizer/metadata/emnist_lines.py ================================================ from pathlib import Path import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json" CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3] DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None) # width variable, depends on maximum sequence length MAPPING = emnist.MAPPING ================================================ FILE: lab06/text_recognizer/metadata/iam.py ================================================ import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam" EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" DOWNSAMPLE_FACTOR = 2 # if images were downsampled, the regions must also be LINE_REGION_PADDING = 8 # add this many pixels around the exact coordinates ================================================ FILE: lab06/text_recognizer/metadata/iam_lines.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines" IMAGE_SCALE_FACTOR = 2 CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR # rounding up IAMLines empirical maximum width DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (89, 1) MAPPING = emnist.MAPPING ================================================ FILE: lab06/text_recognizer/metadata/iam_paragraphs.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN] # must match IMAGE_SCALE_FACTOR for IAMLines to be compatible with synthetic paragraphs IMAGE_SCALE_FACTOR = 2 IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640 IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) MAX_LABEL_LENGTH = 682 DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) ================================================ FILE: lab06/text_recognizer/metadata/iam_synthetic_paragraphs.py ================================================ import text_recognizer.metadata.iam_paragraphs as iam_paragraphs import text_recognizer.metadata.shared as shared NEW_LINE_TOKEN = iam_paragraphs.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs" EXPECTED_BATCH_SIZE = 64 EXPECTED_GPUS = 8 EXPECTED_STEPS = 40 # set the dataset's length based on parameters during typical training DATASET_LEN = EXPECTED_BATCH_SIZE * EXPECTED_GPUS * EXPECTED_STEPS ================================================ FILE: lab06/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab06/text_recognizer/metadata/shared.py 
================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab06/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple from .resnet_transformer import ResnetTransformer from .line_cnn_transformer import LineCNNTransformer ================================================ FILE: lab06/text_recognizer/models/cnn.py ================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. 
Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab06/text_recognizer/models/line_cnn.py ================================================ """Basic building blocks for convolutional models over lines of text.""" import argparse import math from typing import Any, Dict, Tuple, Union import torch from torch import nn import torch.nn.functional as F # Common type hints Param2D = Union[int, Tuple[int, int]] CONV_DIM = 32 FC_DIM = 512 FC_DROPOUT = 0.2 WINDOW_WIDTH = 16 WINDOW_STRIDE = 8 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__( self, input_channels: int, output_channels: int, kernel_size: Param2D = 3, stride: Param2D = 1, padding: Param2D = 1, ) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class LineCNN(nn.Module): """ Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits """ def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.args = vars(args) if args is not None else {} self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] _C, H, _W = data_config["input_dims"] conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) # Input is (1, H, W) self.convs = nn.Sequential( ConvBlock(1, conv_dim), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim, stride=2), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim * 2, stride=2), ConvBlock(conv_dim * 2, conv_dim * 2), ConvBlock(conv_dim * 2, conv_dim * 4, stride=2), ConvBlock(conv_dim * 4, conv_dim * 4), ConvBlock( conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0 ), ) self.fc1 = nn.Linear(fc_dim, fc_dim) self.dropout = nn.Dropout(fc_dropout) self.fc2 = nn.Linear(fc_dim, self.num_classes) self._init_weights() def _init_weights(self): """ Initialize weights in a better way than default. 
See https://github.com/pytorch/pytorch/issues/18182 """ for m in self.modules(): if type(m) in { nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear, }: nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu") if m.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(m.bias, -bound, bound) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the LineCNN to a black-and-white input image. Parameters ---------- x (B, 1, H, W) input image Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and self.window_width C is self.num_classes """ _B, _C, _H, _W = x.shape x = self.convs(x) # (B, FC_DIM, 1, Sx) x = x.squeeze(2).permute(0, 2, 1) # (B, S, FC_DIM) x = F.relu(self.fc1(x)) # -> (B, S, FC_DIM) x = self.dropout(x) x = self.fc2(x) # (B, S, C) x = x.permute(0, 2, 1) # -> (B, C, S) if self.limit_output_length: x = x[:, :, : self.output_length] return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab06/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import torch from 
torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, 
default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab06/text_recognizer/models/line_cnn_transformer.py ================================================ """Model that combines a LineCNN with a Transformer model for text prediction.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .line_cnn import LineCNN from .transformer_util import generate_square_subsequent_mask, PositionalEncoding TF_DIM = 256 TF_FC_DIM = 256 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 class LineCNNTransformer(nn.Module): """Process the line through a CNN and process the resulting sequence with a Transformer decoder.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # Instantiate LineCNN with "num_classes" set to self.dim data_config_for_line_cnn = {**data_config} data_config_for_line_cnn["mapping"] = list(range(self.dim)) self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args) # LineCNN outputs (B, E, S) log probs, with E == dim self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.pos_encoder = PositionalEncoding(d_model=self.dim) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (Sx, B, E) logits """ x = self.line_cnn(x) # (B, E, Sx) x = x * math.sqrt(self.dim) x = x.permute(2, 0, 1) # (Sx, B, E) x = self.pos_encoder(x) # (Sx, B, E) return x def decode(self, x, y): """Decode a batch of encoded images x using preceding ground truth y.
Parameters ---------- x (Sx, B, E) image encoded as a sequence y (B, Sy) with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) logits """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output def forward(self, x: torch.Tensor) -> torch.Tensor: """Predict sequences of tokens from input images auto-regressively. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1:] # Set the last output token # Set all tokens after end token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) @staticmethod def add_to_argparse(parser): LineCNN.add_to_argparse(parser) parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser ================================================ FILE: lab06/text_recognizer/models/mlp.py 
================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab06/text_recognizer/models/resnet_transformer.py ================================================ """Model combining a ResNet with a Transformer for image-to-sequence tasks.""" import argparse import math from typing import Any, Dict import torch from torch import nn import torchvision from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage TF_DIM = 256 TF_FC_DIM = 1024 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 RESNET_DIM = 512 # hard-coded class ResnetTransformer(nn.Module): """Pass an image through a Resnet and decode the resulting embedding with a Transformer.""" def __init__( 
self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) self.mapping = data_config["mapping"] inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # ## Encoder part - should output vector sequence of length self.dim per sample resnet = torchvision.models.resnet18(weights=None) self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) # encoder_projection will output (B, dim, _H, _W) logits self.enc_pos_encoder = PositionalEncodingImage( d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] ) # Max (Ho, Wo) # ## Decoder part self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def forward(self, x: torch.Tensor) -> torch.Tensor: """Autoregressively produce sequences of labels from input images. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- output_tokens (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1] # Set the last output token # Early stopping of prediction loop to speed up prediction if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): break # Set all tokens after end or padding token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") if self.encoder_projection.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.encoder_projection.bias, -bound, bound) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps """ _B, C, _H, _W = x.shape if C == 1: x = x.repeat(1, 3, 1, 1) x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo return x def decode(self, x, y): """Decode a batch of encoded images x with guiding sequences y. During autoregressive inference, the guiding sequence will be previous predictions. During training, the guiding sequence will be the ground truth. Parameters ---------- x (Sx, B, E) images encoded as sequences of embeddings y (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) batch of logit sequences """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.dec_pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output @staticmethod def add_to_argparse(parser): parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser 
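The greedy loop in `ResnetTransformer.forward` above can be sketched on plain Python lists, which makes the early-stopping check and the final padding pass easier to follow. This is only an illustrative sketch, not the repository's implementation: `greedy_decode`, `toy_next_token`, and the token indices `START`, `END`, and `PAD` are hypothetical stand-ins for the model's decode step and its start, end, and padding tokens.

```python
from typing import Callable, List

# Hypothetical token indices, chosen only for this illustration.
START, END, PAD = 1, 2, 3


def greedy_decode(
    next_token: Callable[[int, List[int]], int], batch_size: int, max_len: int
) -> List[List[int]]:
    """Greedy autoregressive decoding on plain lists, mirroring the structure of forward()."""
    # Start from all-padding sequences with a start token in position 0.
    tokens = [[PAD] * max_len for _ in range(batch_size)]
    for row in tokens:
        row[0] = START

    for sy in range(1, max_len):
        # Each sequence gets its next token from its own prefix.
        for b, row in enumerate(tokens):
            row[sy] = next_token(b, row[:sy])
        # Early stopping: only fires if every sequence just emitted END or PAD.
        if all(row[sy] in (END, PAD) for row in tokens):
            break

    # Cleanup pass: anything that follows an END or PAD token becomes PAD.
    for sy in range(1, max_len):
        for row in tokens:
            if row[sy - 1] in (END, PAD):
                row[sy] = PAD
    return tokens


def toy_next_token(b: int, prefix: List[int]) -> int:
    """Toy 'model': sequence 0 ends after 2 tokens, the others after 4, then emit junk."""
    stop_at = 2 if b == 0 else 4
    return END if len(prefix) == stop_at else 10 + len(prefix)


decoded = greedy_decode(toy_next_token, batch_size=2, max_len=6)
print(decoded)  # [[1, 11, 2, 3, 3, 3], [1, 11, 12, 13, 2, 3]]
```

In the toy run, sequence 0 keeps producing junk tokens after its end token because the loop continues for the sake of sequence 1. The cleanup pass is what overwrites that junk with padding, which is why it is needed in addition to early stopping.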
================================================ FILE: lab06/text_recognizer/models/transformer_util.py ================================================ """Position Encoding and other utilities for Transformers.""" import math import torch from torch import Tensor import torch.nn as nn class PositionalEncodingImage(nn.Module): """ Module used to add 2-D positional encodings to the feature-map produced by the encoder. Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. """ def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None: super().__init__() self.d_model = d_model assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w) self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor: pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1, d_model // 2) pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) return pe def forward(self, x: Tensor) -> Tensor: """pytorch.nn.module.forward""" # x.shape = (B, d_model, H, W) assert x.shape[1] == self.pe.shape[0] # type: ignore x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore return x class PositionalEncoding(torch.nn.Module): """Classic Attention-is-all-you-need positional encoding.""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None: super().__init__() self.dropout = torch.nn.Dropout(p=dropout) pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model) 
self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_len: int) -> torch.Tensor: pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(1) return pe def forward(self, x: torch.Tensor) -> torch.Tensor: # x.shape = (S, B, d_model) assert x.shape[2] == self.pe.shape[2] # type: ignore x = x + self.pe[: x.size(0)] # type: ignore return self.dropout(x) def generate_square_subsequent_mask(size: int) -> torch.Tensor: """Generate a triangular (size, size) mask.""" mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask ================================================ FILE: lab06/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible. 
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab06/text_recognizer/stems/line.py ================================================ import random from PIL import Image from torchvision import transforms import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.image import ImageStem class LineStem(ImageStem): """A stem for handling images containing a line of text.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.5, 1)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "translate": (0, 0.05), "scale": (0.4, 1.1), "shear": (-40, 50), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } if augment: self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] ) class IAMLineStem(ImageStem): """A stem for handling images containing lines of text from the IAMLines dataset.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() def embed_crop(crop, augment=augment): # crop is PIL.image of dtype="L" (so values range from 0 -> 255) image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) # Resize crop crop_width, crop_height = crop.size new_crop_height = metadata.IMAGE_HEIGHT new_crop_width = int(new_crop_height * 
(crop_width / crop_height)) if augment: # Add random stretching new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR) # Embed in the image x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) y = metadata.IMAGE_HEIGHT - new_crop_height image.paste(crop_resized, (x, y)) return image if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.8, 1.6)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 1, "shear": (-30, 20), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } pil_transforms_list = [transforms.Lambda(embed_crop)] if augment: pil_transforms_list += [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] self.pil_transforms = transforms.Compose(pil_transforms_list) ================================================ FILE: lab06/text_recognizer/stems/paragraph.py ================================================ """IAMParagraphs Stem class.""" import torchvision.transforms as transforms import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.image import ImageStem IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH class ParagraphStem(ImageStem): """A stem for handling images that contain a paragraph of text.""" def __init__( self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None, random_perspective_kwargs=None, gaussian_blur_kwargs=None, sharpness_kwargs=None, ): super().__init__() if not augment: self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)]) else: if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "shear": 6, 
"scale": (0.95, 1), "interpolation": transforms.InterpolationMode.BILINEAR, } if random_perspective_kwargs is None: random_perspective_kwargs = { "distortion_scale": 0.2, "p": 0.5, "interpolation": transforms.InterpolationMode.BILINEAR, } if gaussian_blur_kwargs is None: gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if sharpness_kwargs is None: sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} # IMAGE_SHAPE is (576, 640) self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomCrop( size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant" ), transforms.RandomAffine(**random_affine_kwargs), transforms.RandomPerspective(**random_perspective_kwargs), transforms.GaussianBlur(**gaussian_blur_kwargs), transforms.RandomAdjustSharpness(**sharpness_kwargs), ] ) ================================================ FILE: lab06/text_recognizer/tests/test_callback_utils.py ================================================ """Tests for the text_recognizer.callbacks.util module.""" import random import string import tempfile import pytorch_lightning as pl from text_recognizer.callbacks.util import check_and_warn def test_check_and_warn_simple(): """Test the success and failure in the case of a simple class we control.""" class Foo: pass # a class with no special attributes letters = string.ascii_lowercase random_attribute = "".join(random.choices(letters, k=10)) assert check_and_warn(Foo(), random_attribute, "random feature") assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects") def test_check_and_warn_tblogger(): """Test that we return a truthy value when trying to log tables with TensorBoard. We added check_and_warn in order to prevent a crash if this happens. 
""" tblogger = pl.loggers.TensorBoardLogger(save_dir=tempfile.TemporaryDirectory()) assert check_and_warn(tblogger, "log_table", "tables") def test_check_and_warn_wandblogger(): """Test that we return a falsy value when we try to log tables with W&B. In adding check_and_warn, we don't want to block the feature in the happy path. """ wandblogger = pl.loggers.WandbLogger(anonymous=True) assert not check_and_warn(wandblogger, "log_table", "tables") ================================================ FILE: lab06/text_recognizer/tests/test_iam.py ================================================ """Test for data.iam module.""" from text_recognizer.data.iam import IAM def test_iam_parsed_lines(): """Tests that we retrieve the same number of line labels and line image cropregions.""" iam = IAM() iam.prepare_data() for iam_id in iam.all_ids: assert len(iam.line_strings_by_id[iam_id]) == len(iam.line_regions_by_id[iam_id]) def test_iam_data_splits(): """Fails when any identifiers are shared between training, test, or validation.""" iam = IAM() iam.prepare_data() assert not set(iam.train_ids) & set(iam.validation_ids) assert not set(iam.train_ids) & set(iam.test_ids) assert not set(iam.validation_ids) & set(iam.test_ids) ================================================ FILE: lab06/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, 
grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab06/training/__init__.py ================================================ ================================================ FILE: lab06/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import callbacks as cb from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, 
MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--wandb", action="store_true", default=False, help="If passed, logs experiment results to Weights & Biases. Otherwise logs only to local Tensorboard.", ) parser.add_argument( "--profile", action="store_true", default=False, help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.", ) parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." 
+ " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. 
To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help ``` """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.loss == "transformer": lit_model_class = lit_models.TransformerLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" if goldstar_metric == "validation/cer": filename_format += "-validation.cer={validation/cer:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.wandb: logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train") logger.watch(model, log_freq=max(100, args.log_every_n_steps)) logger.log_hyperparams(vars(args)) experiment_dir = logger.experiment.dir callbacks += [cb.ModelSizeLogger(), cb.LearningRateMonitor()] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) if args.wandb and args.loss in ("transformer",): callbacks.append(cb.ImageToTextLogger()) trainer = 
pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) if args.profile: sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0) profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir) profiler.STEP_FUNCTIONS = {"training_step"} # only profile training else: profiler = pl.profiler.PassThroughProfiler() trainer.profiler = profiler trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) trainer.profiler = pl.profiler.PassThroughProfiler() # turn profiling off during testing best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") if args.wandb: rank_zero_info("Best model also uploaded to W&B ") trainer.test(datamodule=data, ckpt_path=best_model_path) else: trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab06/training/tests/test_memorize_iam.sh ================================================ #!/bin/bash set -uo pipefail set +e # tests whether we can achieve a criterion loss # on a single batch within a certain number of epochs FAILURE=false # constants and CLI args set by aiming for <5 min test on commodity GPU, # including data download step MAX_EPOCHS="${1:-100}" # syntax for basic optional arguments in bash CRITERION="${2:-1.0}" # train on GPU if it's available GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))') python ./training/run_experiment.py \ --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \ --augment_data false --tf_dropout 0.0 \ --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \ --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true python -c "import json; loss = 
json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true if [ "$FAILURE" = true ]; then echo "Memorization test failed at loss criterion $CRITERION" exit 1 fi echo "Memorization test passed at loss criterion $CRITERION" exit 0 ================================================ FILE: lab06/training/tests/test_run_experiment.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false echo "running full loop test with CNN on fake data" python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1 || FAILURE=true echo "running fast_dev_run test of real model class on real data" python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \ --fast_dev_run --num_sanity_val_steps 0 \ --num_workers 1 || FAILURE=true if [ "$FAILURE" = true ]; then echo "Test for run_experiment.py failed" exit 1 fi echo "Tests for run_experiment.py passed" exit 0 ================================================ FILE: lab06/training/util.py ================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 
'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: lab07/.flake8 ================================================ [flake8] select = ANN,B,B9,BLK,C,D,E,F,I,S,W # only check selected error codes max-complexity = 12 # C9 - flake8 McCabe Complexity checker -- threshold max-line-length = 120 # E501 - flake8 -- line length too long, actually handled by black extend-ignore = # E W - flake8 PEP style check E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks # S - flake8-bandit safety check S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password # ANN - flake8-annotations type annotation check ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some # D1 - flake8-docstrings docstring style check D100,D102,D103,D104,D105, # missing docstrings # D2 D4 - flake8-docstrings docstring style check D200,D205,D400,D401, # whitespace issues and first line content # DAR - flake8-darglint docstring correctness check DAR103, # mismatched or missing type in docstring application-import-names = app_gradio,text_recognizer,tests,training # flake8-import-order: which names are first party? import-order-style = google # flake8-import-order: which import order style guide do we use? docstring-convention = numpy # flake8-docstrings: which docstring style guide do we use? strictness = short # darglint: how "strict" are we with docstring completeness? 
docstring-style = numpy # darglint: which docstring style guide do we use? suppress-none-returning = true # flake8-annotations: do we allow un-annotated Nones in returns? mypy-init-return = true # flake8-annotations: do we allow init to have no return annotation? per-file-ignores = # list of case-by-case ignores, see files for details */__init__.py:F401,I */data/*.py:DAR data/*.py:F,I *text_recognizer/util.py:DAR101,F401 *training/run_experiment.py:I202 *app_gradio/app.py:I202 ================================================ FILE: lab07/.github/workflows/pre-commit.yml ================================================ name: pre-commit on: pull_request: push: # allows this Action to be triggered manually workflow_dispatch: jobs: pre-commit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: python-version: '3.10' - uses: pre-commit/action@v3.0.0 ================================================ FILE: lab07/.pre-commit-config.yaml ================================================ repos: # a set of useful Python-based pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: # list of definitions and supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace # removes any whitespace at the ends of lines - id: check-toml # check toml syntax by loading all toml files - id: check-yaml # check yaml syntax by loading all yaml files - id: check-json # check-json syntax by loading all json files - id: check-merge-conflict # check for files with merge conflict strings args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge - id: check-added-large-files # check that no "large" files have been added args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server - id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.) - id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.) 
# black python autoformatting - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black # additional configuration of black in pyproject.toml # flake8 python linter with all the fixins - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08) additional_dependencies: [ flake8-bandit, flake8-bugbear, flake8-docstrings, flake8-import-order, darglint, mypy, pycodestyle, pydocstyle] args: ["--config", ".flake8"] # additional configuration of flake8 and extensions in .flake8 # shellcheck-py for linting shell files - repo: https://github.com/shellcheck-py/shellcheck-py rev: v0.8.0.4 hooks: - id: shellcheck ================================================ FILE: lab07/api_serverless/Dockerfile ================================================ # Starting from an official AWS image # Keep any dependencies and versions in this file aligned with the environment.yml and Makefile FROM public.ecr.aws/lambda/python:3.10 # Install Python dependencies COPY requirements/prod.txt ./requirements.txt RUN pip install --upgrade pip==23.1.2 RUN pip install -r requirements.txt # Copy only the relevant directories and files # note that we use a .dockerignore file to avoid copying logs etc. 
COPY text_recognizer/ ./text_recognizer COPY api_serverless/api.py ./api.py CMD ["api.handler"] ================================================ FILE: lab07/api_serverless/__init__.py ================================================ """Cloud function-backed API for paragraph recognition.""" ================================================ FILE: lab07/api_serverless/api.py ================================================ """AWS Lambda function serving text_recognizer predictions.""" import json from PIL import ImageStat from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer import text_recognizer.util as util model = ParagraphTextRecognizer() def handler(event, _context): """Provide main prediction API.""" print("INFO loading image") image = _load_image(event) if image is None: return {"statusCode": 400, "message": "neither image_url nor image found in event"} print("INFO image loaded") print("INFO starting inference") pred = model.predict(image) print("INFO inference complete") image_stat = ImageStat.Stat(image) print("METRIC image_mean_intensity {}".format(image_stat.mean[0])) print("METRIC image_area {}".format(image.size[0] * image.size[1])) print("METRIC pred_length {}".format(len(pred))) print("INFO pred {}".format(pred)) return {"pred": str(pred)} def _load_image(event): event = _from_string(event) event = _from_string(event.get("body", event)) image_url = event.get("image_url") if image_url is not None: print("INFO url {}".format(image_url)) return util.read_image_pil(image_url, grayscale=True) else: image = event.get("image") if image is not None: print("INFO reading image from event") return util.read_b64_image(image, grayscale=True) else: return None def _from_string(event): if isinstance(event, str): return json.loads(event) else: return event ================================================ FILE: lab07/app_gradio/Dockerfile ================================================ # The "buster" flavor of the official docker Python image 
is based on Debian and includes common packages. # Keep any dependencies and versions in this file aligned with the environment.yml and Makefile FROM python:3.10-buster # Create the working directory # set -x prints commands and set -e causes us to stop on errors RUN set -ex && mkdir /repo WORKDIR /repo # Install Python dependencies COPY requirements/prod.txt ./requirements.txt RUN pip install --upgrade pip==23.1.2 RUN pip install -r requirements.txt ENV PYTHONPATH ".:" # Copy only the relevant directories # note that we use a .dockerignore file to avoid copying logs etc. COPY text_recognizer/ ./text_recognizer COPY app_gradio/ ./app_gradio # Use docker run -it --rm -p $PORT:11717 to run the web server and listen on host $PORT # add --help to see help for the Python script ENTRYPOINT ["python3", "app_gradio/app.py", "--port", "11717"] ================================================ FILE: lab07/app_gradio/README.md ================================================ ## Full-Paragraph Optical Character Recognition For more on how this application works, [check out the GitHub repo](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022). ### Flagging If the model outputs in the top-right are wrong in some way, let us know by clicking the "flagging" buttons underneath. We'll analyze the results with [Gantry](https://gantry.io/blog/introducing-gantry/) and use them to improve the model! 
================================================ FILE: lab07/app_gradio/__init__.py ================================================ ================================================ FILE: lab07/app_gradio/app.py ================================================ """Provide an image of handwritten text and get back out a string!""" import argparse import json import logging import os from pathlib import Path from typing import Callable import gradio as gr from PIL import ImageStat from PIL.Image import Image import requests from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer import text_recognizer.util as util os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU logging.basicConfig(level=logging.INFO) APP_DIR = Path(__file__).resolve().parent # what is the directory for this application? FAVICON = APP_DIR / "1f95e.png" # path to a small image for display in browser tab and social media README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown DEFAULT_PORT = 11700 def main(args): predictor = PredictorBackend(url=args.model_url) frontend = make_frontend( predictor.run, ) frontend.launch( server_name="0.0.0.0", # make server accessible, binding all interfaces # noqa: S104 server_port=args.port, # set a port to bind to, failing if unavailable share=True, # should we create a (temporary) public link on https://gradio.app? favicon_path=FAVICON, # what icon should we display in the address bar? 
) def make_frontend( fn: Callable[[Image], str], ): """Creates a gradio.Interface frontend for an image to text function.""" examples_dir = Path("text_recognizer") / "tests" / "support" / "paragraphs" example_fnames = [elem for elem in os.listdir(examples_dir) if elem.endswith(".png")] example_paths = [examples_dir / fname for fname in example_fnames] examples = [[str(path)] for path in example_paths] allow_flagging = "never" readme = _load_readme(with_logging=allow_flagging == "manual") # build a basic browser interface to a Python function frontend = gr.Interface( fn=fn, # which Python function are we interacting with? outputs=gr.components.Textbox(), # what output widgets does it need? the default text widget # what input widgets does it need? we configure an image widget inputs=gr.components.Image(type="pil", label="Handwritten Text"), title="📝 Text Recognizer", # what should we display at the top of the page? thumbnail=FAVICON, # what should we display when the link is shared, e.g. on social media? description=__doc__, # what should we display just above the interface? article=readme, # what long-form content should we display below the interface? examples=examples, # which potential inputs should we provide? cache_examples=False, # should we cache those inputs for faster inference? slows down start allow_flagging=allow_flagging, # should we show users the option to "flag" outputs? ) return frontend class PredictorBackend: """Interface to a backend that serves predictions. To communicate with a backend accessible via a URL, provide the url kwarg. Otherwise, runs a predictor locally. 
""" def __init__(self, url=None): if url is not None: self.url = url self._predict = self._predict_from_endpoint else: model = ParagraphTextRecognizer() self._predict = model.predict def run(self, image): pred, metrics = self._predict_with_metrics(image) self._log_inference(pred, metrics) return pred def _predict_with_metrics(self, image): pred = self._predict(image) stats = ImageStat.Stat(image) metrics = { "image_mean_intensity": stats.mean, "image_median": stats.median, "image_extrema": stats.extrema, "image_area": image.size[0] * image.size[1], "pred_length": len(pred), } return pred, metrics def _predict_from_endpoint(self, image): """Send an image to an endpoint that accepts JSON and return the predicted text. The endpoint should expect a base64 representation of the image, encoded as a string, under the key "image". It should return the predicted text under the key "pred". Parameters ---------- image A PIL image of handwritten text to be converted into a string. Returns ------- pred A string containing the predictor's guess of the text in the image. """ encoded_image = util.encode_b64_image(image) headers = {"Content-type": "application/json"} payload = json.dumps({"image": "data:image/png;base64," + encoded_image}) response = requests.post(self.url, data=payload, headers=headers) pred = response.json()["pred"] return pred def _log_inference(self, pred, metrics): for key, value in metrics.items(): logging.info(f"METRIC {key} {value}") logging.info(f"PRED >begin\n{pred}\nPRED >end") def _load_readme(with_logging=False): with open(README) as f: lines = f.readlines() if not with_logging: lines = lines[: lines.index("\n")] readme = "".join(lines) return readme def _make_parser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--model_url", default=None, type=str, help="Identifies a URL to which to send image data. Data is base64-encoded, converted to a utf-8 string, and then set via a POST request as JSON with the key 'image'. 
Default is None, which instead sends the data to a model running locally.", ) parser.add_argument( "--port", default=DEFAULT_PORT, type=int, help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.", ) return parser if __name__ == "__main__": parser = _make_parser() args = parser.parse_args() main(args) ================================================ FILE: lab07/app_gradio/tests/test_app.py ================================================ import json import os import requests from app_gradio import app from text_recognizer import util os.environ["CUDA_VISIBLE_DEVICES"] = "" TEST_IMAGE = "text_recognizer/tests/support/paragraphs/a01-077.png" def test_local_run(): """A quick test to make sure we can build the app and ping the API locally.""" backend = app.PredictorBackend() frontend = app.make_frontend(fn=backend.run) # run the UI without blocking frontend.launch(share=False, prevent_thread_lock=True) local_url = frontend.local_url get_response = requests.get(local_url) assert get_response.status_code == 200, get_response.content image_b64 = util.encode_b64_image(util.read_image_pil(TEST_IMAGE)) local_api = f"{local_url}api/predict" headers = {"Content-Type": "application/json"} payload = json.dumps({"data": ["data:image/png;base64," + image_b64]}) post_response = requests.post(local_api, data=payload, headers=headers) assert post_response.status_code == 200, post_response.content ================================================ FILE: lab07/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like 
`torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For refreshing on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": "markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the Python standard library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing\n", "and matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models." 
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randrange(len(x_train)) # randrange excludes the upper bound, so idx is always a valid index\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch." 
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how well our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5 # learning rate hyperparameter\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs): # loop over the data repeatedly\n", " for ii in range((n - 1) // bs + 1): # in batches of size bs, so roughly n / bs of them\n", " start_idx = ii * bs # we are ii batches in, each of size bs\n", " end_idx = start_idx + bs # and we want the next bs entries\n", "\n", " # pull batches from x and from y\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", "\n", " # run model\n", " pred = model(xb)\n", "\n", " # get loss\n", " loss = loss_func(pred, yb)\n", "\n", " # calculate the gradients with a backwards pass\n", " loss.backward()\n", "\n", " # update the parameters\n", " with torch.no_grad(): # we don't want to track gradients through this part!\n", " # SGD learning rule: update with negative gradient scaled by lr\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", "\n", " # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", " # until you say so -- by explicitly \"deleting\" them,\n", " # i.e. 
setting the gradients to 0.\n", " weights.grad.zero_()\n", " bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset was, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`."
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__() # the nn.Module.__init__ method does important setup, so this is mandatory\n", " self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", " self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): # finds params automatically\n", " p -= p.grad * lr\n", " model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", " def forward(self, xb):\n", " return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", " opt = configure_optimizer(self)\n", "\n", " for epoch in range(epochs):\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!" 
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", " weights.grad.zero_()\n", " bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move around electrons."
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", " return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", " \n", " def __init__(self, dir, bs=32):\n", " self.dir = dir\n", " self.bs = bs\n", " self.path = self.dir / self.filename\n", "\n", " def prepare_data(self):\n", " if not (self.path).exists():\n", " content = requests.get(self.url + self.filename).content\n", " self.path.open(\"wb\").write(content)\n", "\n", " def setup(self):\n", " with gzip.open(self.path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", " x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", " )\n", " \n", " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", " def train_dataloader(self):\n", " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", " \n", " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", " datamodule.prepare_data()\n", " datamodule.setup()\n", "\n", " val_dataloader = datamodule.val_dataloader()\n", " \n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", " opt = configure_optimizer(self)\n", " train_dataloader = datamodule.train_dataloader()\n", " for epoch in range(epochs):\n", " self.train()\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab07/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", "    inputs, targets = batch\n", "\n", "    outputs = model(inputs)\n", "    loss = some_loss_function(outputs, targets)\n", "\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "\n", "    optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes."
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", "    def __init__(self):\n", "        super().__init__()  # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", "        # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", "        self.model = torch.nn.Linear(in_features=1, out_features=1)\n", "        # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", "    # optionally, define a forward method\n", "    def forward(self, xs):\n", "        return self.model(xs)  # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging  # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try:  # try using the LinearRegression LightningModule defined above\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)  # hide some info for now\n", "\n", "    model = LinearRegression()\n", "\n", "    # we'll explain how the Trainer works in a bit\n", "    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", "    trainer.fit(model=model)\n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", "    print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\")  # show the error without raising it\n", "\n", "finally:  # bring back info-level logging\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", "    # run the train step\n", "    loss = training_step(batch)\n", "\n", "    # clear gradients\n", "    optimizer.zero_grad()\n", "\n", "    # backprop\n", "    loss.backward()\n", "\n", "    # update parameters\n", "    optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "    xs, ys = batch  # unpack the batch\n", "    outs = self(xs)  # apply the model\n", "    loss = torch.nn.functional.mse_loss(outs, ys)  # compute the (squared error) loss\n", "    return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudocode (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", "    def closure():  # wrap the loss calculation\n", "        loss = model.training_step(batch, batch_idx, ...)\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        return loss\n", "\n", "    # optimizer can call the loss calculation as many times as it likes\n", "    optimizer.step(closure)  # some optimizers need this, like L-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", "    optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)  # https://fsdl.me/ol-reliable-img\n", "    return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n",
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", "    def __init__(self, N=10_000):\n", "        self.N = N\n", "        self.xs = torch.randn(size=(N, 1))\n", "        self.ys = torch.randn_like(self.xs) + self.xs  # correlated target data: y ~ N(x, 1)\n", "\n", "    def __getitem__(self, idx):\n", "        return (self.xs[idx], self.ys[idx])\n", "\n", "    def __len__(self):\n", "        return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl))  # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", "    .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging " ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", " def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", " super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", " # set some constants, like the train/val split\n", " self.size = size\n", " self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", " self.train_indices = list(range(math.floor(self.size * train_frac)))\n", " # validation indices start right after the last training index, so the splits are disjoint\n", " self.val_indices = list(range(len(self.train_indices), self.size))\n", "\n", " # under the hood, we've still got a torch Dataset\n", " self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up once on each machine\n", "(for example, downloading data from storage and writing it to the local disk)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance-sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also, implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates performance best practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", " def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n", " super().__init__() # don't forget this!\n", " self.dataset = None\n", " self.test_dataset = None # define a test set -- another sample from the same distribution\n", "\n", " def setup(self, stage=None):\n", " pass\n", "\n", " def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " pass # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab07/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?" 
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we repeatedly apply that kernel over and over again,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type": 
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh" 
}, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represents (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)."
] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "arguments.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of 
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work off of\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors."
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like Dropout\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by the layer can be higher\n", "than that done by layers with more parameters --\n", "much higher."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2d layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n", " num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", " input_height, input_width = input_size[1], input_size[2]\n", "\n", " multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", " num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n", " multiplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", " return multiplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory-constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \\n in filenames\n", "filter_to_ckpts = \"grep \\.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"CNN\",\n", " \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", " --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels."
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `<P>`.\n", "\n", "A model that always predicts `<P>` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"<P>\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab07/notebooks/lab03_transformers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 03: Transformers and Paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The fundamental reasons why the Transformer is such\n", "a powerful and popular architecture\n", "- Core intuitions for the behavior of Transformer architectures\n", "- How to use a convolutional encoder and a Transformer decoder to recognize\n", "entire paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 3\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", "    # path management for Python\n", "    pythonpath, = !echo $PYTHONPATH\n", "    if \".\" not in pythonpath.split(\":\"):\n", "        pythonpath = \".:\" + pythonpath\n", "        %env PYTHONPATH={pythonpath}\n", "        !echo $PYTHONPATH\n", "\n", "    # get both Colab and local notebooks into the same state\n", "    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", "    import bootstrap\n", "\n", "    # change into the lab directory\n", "    bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", "    # allow \"hot-reloading\" of modules\n", "    %load_ext autoreload\n", "    %autoreload 2\n", "    # needed for inline plots in some contexts\n", "    %matplotlib inline\n", "\n", "    bootstrap.run = False  # change to True to re-run setup\n", "    \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Transformers?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal in building a text recognizer is to take a two-dimensional image\n", "and convert it into a one-dimensional sequence of characters\n", "from some alphabet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional neural networks,\n", "discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n", "are great at encoding images,\n", "taking them from their raw pixel values\n", "to a more semantically meaningful numerical representation."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But how do we go from that to a sequence of letters?\n", "And what's especially tricky:\n", "the number of letters in an image is separable from its size.\n", "A screenshot of this document has a much higher density of letters\n", "than a close-up photograph of a piece of paper.\n", "How do we get a _variable-length_ sequence of letters,\n", "where the length need have nothing to do with the size of the input tensor?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n", "they were\n", "[originally introduced](https://arxiv.org/abs/1706.03762)\n", "for transforming one sequence into another,\n", "as in machine translation.\n", "This makes them a natural fit for processing language.\n", "\n", "But they have also found success in other domains --\n", "at the time of this writing, large transformers\n", "dominate the\n", "[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n", "that has become a de facto standard for comparing models\n", "and are finding\n", "[application in reinforcement learning](https://arxiv.org/abs/2106.01345)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we will use a Transformer as a key component of our final architecture:\n", "we will encode our input images with a CNN\n", "and then read them out into a text sequence with a Transformer.\n", "\n", "Before trying out this new model,\n", "let's first get an understanding of why the Transformer architecture\n", "has become so popular by walking through its history\n", "and then get some intuition for how it works\n", "by looking at some\n", "[recent work](https://transformer-circuits.pub/)\n", "on explaining the behavior of both toy models and state-of-the-art language models." 
] }, { "cell_type": "markdown", "metadata": { "id": "kmKqjbvd-Mj3" }, "source": [ "## Why not convolutions?" ] }, { "cell_type": "markdown", "metadata": { "id": "SRqkUMdM-OxU" }, "source": [ "In the ancient beforetimes (i.e. 2016),\n", "the best models for natural language processing were all\n", "_recurrent_ neural networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional networks were also occasionally used,\n", "but they suffered from a serious issue:\n", "their architectural biases don't fit text.\n", "\n", "First, _translation equivariance_ no longer holds.\n", "The beginning of a piece of text is often quite different from the middle,\n", "so the absolute position matters.\n", "\n", "Second, _locality_ is not as important in language.\n", "The name of a character that hasn't appeared in thousands of pages\n", "can become salient when someone asks, \"Whatever happened to\n", "[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n", "\n", "Consider interpreting a piece of text like the Python code below:\n", "```python\n", "def do(arg1, arg2, arg3):\n", " a = arg1 + arg2\n", " b = arg3[:3]\n", " c = a * b\n", " return c\n", "\n", "print(do(1, 1, \"ayy lmao\"))\n", "```\n", "\n", "After a `(` we expect a `)`,\n", "but possibly very long afterwards,\n", "[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n", "and similarly we expect a `]` at some point after a `[`.\n", "\n", "For translation variance, consider\n", "that we interpret `*` not by\n", "comparing it to its neighbors\n", "but by looking at `a` and `b`.\n", "We mix knowledge learned through experience\n", "with new facts learned while reading --\n", "also known as _in-context learning_.\n", "\n", "In a longer text,\n", "[e.g. 
the one you are reading now](./lab03_transformers.ipynb),\n", "the translation variance of text is clearer.\n", "Every lab notebook begins with the same header,\n", "setting up the environment,\n", "but that header never appears elsewhere in the notebook.\n", "Later positions need to be processed in terms of the previous entries.\n", "\n", "Unlike an image, we cannot simply rotate or translate our \"camera\"\n", "and get a new valid text.\n", "[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n", "that can be read without regard to position." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The field of formal language theory,\n", "which has deep mutual influence with computer science,\n", "gives one way of explaining the issues with convolutional networks:\n", "they can only understand languages with _finite contexts_,\n", "where all the information can be found within a finite window." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The immediate solution, drawing from the connections to computer science, is\n", "[recursion](https://www.google.com/search?q=recursion).\n", "A network whose output on the final entry of the sequence is a recursive function\n", "of all the previous entries can build up knowledge\n", "as it reads the sequence and treat early entries quite differently than it does late ones." 
] }, { "cell_type": "markdown", "metadata": { "id": "aa6cbTlImkEh" }, "source": [ "In pseudo-code, such a _recurrent neural network_ module might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "lKtBoPnglPrW" }, "source": [ "```python\n", "def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n", " next_inputs = input_module(xs[-1])\n", " next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n", " return output_module(next_inputs, next_hiddens)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "IbJPSMnEm516" }, "source": [ "If you've had formal computer science training,\n", "then you may be familiar with the power of recursion,\n", "e.g. the\n", "[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n", "that gave its name to the now much better-known\n", "[startup incubator](https://www.ycombinator.com/).\n", "\n", "The particular form of recursion used by\n", "recurrent neural networks implements a\n", "[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n", "\n", "> If you know a lot of computer science,\n", "you might be concerned by this connection.\n", "What about other\n", "[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n", "Where are the neural network architectures for differentiable\n", "[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n", "Check out Graph Neural Networks,\n", "[which implement dynamic programming](https://arxiv.org/abs/2203.15544)."
] }, { "cell_type": "markdown", "metadata": { "id": "63mMTbEBpVuE" }, "source": [ "Recurrent networks are able to achieve\n", "[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n", "\n", "There are many popular recurrent architectures,\n", "from the beefy and classic\n", "[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n", "and the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n", "([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n", "all of which have roughly similar capabilities but\n", "[some of which are easier to train](https://arxiv.org/abs/1611.09913)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwQHVTIslOku" }, "source": [ "In the same sense that MLPs can model \"any\" feedforward function,\n", "in principle even basic RNNs\n", "[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n", "\n", "In particular they can model any\n", "[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n", "which is a formal way of saying that they can in principle\n", "do anything a computer is capable of doing.\n", "\n", "The question is then..." ] }, { "cell_type": "markdown", "metadata": { "id": "3J8EoGN3pu7P" }, "source": [ "## Why aren't we all using RNNs?" 
] }, { "cell_type": "markdown", "metadata": { "id": "TDwNWaevpt_3" }, "source": [ "The guarantees that MLPs can model any function\n", "or that RNNs can model Turing machines\n", "provide decent intuition but are not directly practically useful.\n", "Among other reasons, they don't guarantee learnability --\n", "that starting from random parameters we can find the parameters\n", "that implement a given function.\n", "The\n", "[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n", "than would seem from basic theoretical and empirical analysis.\n", "\n", "One way of understanding capacity to model language is\n", "[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n", "In this model of formal languages,\n", "Turing machines sit at the top\n", "([practically speaking](https://arxiv.org/abs/math/0209332)).\n", "\n", "With better mathematical models,\n", "RNNs and LSTMs can be shown to be\n", "[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n", "with RNNs looking more like\n", "[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n", "and LSTMs coming in\n", "[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n", "\n", "More controversially:\n", "the Chomsky hierarchy is great for understanding syntax and grammar,\n", "which makes it great for building parsers\n", "and working with formal languages,\n", "but the goal in _natural_ language processing is to understand _natural_ language.\n", "Most humans' natural language is far from strictly grammatical,\n", "but that doesn't mean it is nonsense.\n", "\n", "And to really \"understand\" language means\n", "to understand its semantic content, which is fuzzy.\n", "The most important thing for handling the fuzzy semantic content\n", "of language is not whether you can recall\n", "[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n", "but whether you can 
model probabilistic relationships between concepts\n", "in addition to grammar and syntax." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These both leave theoretical room for improvement over current recurrent\n", "language and sequence models.\n", "\n", "But the real cause of the rise of Transformers is that..." ] }, { "cell_type": "markdown", "metadata": { "id": "Dsu1ebvAp-3Z" }, "source": [ "## Transformers are designed to train fast at scale on contemporary hardware." ] }, { "cell_type": "markdown", "metadata": { "id": "c4abU5adsPGs" }, "source": [ "The Transformer architecture has several important features,\n", "discussed below,\n", "but one of the most important reasons for its success\n", "is that it can be more easily trained at scale.\n", "\n", "This scalability is the focus of the discussion in the paper\n", "that introduced the architecture,\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", "and\n", "[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n", "\n", "The recursion in RNNs is inherently sequential:\n", "the dependence on the outputs from earlier in the sequence\n", "means computations within an example cannot be parallelized.\n", "\n", "So RNNs must batch across examples to scale,\n", "but as sequence length grows this hits memory bandwidth limits.\n", "Serving up large batches quickly with good randomness guarantees\n", "is also hard to optimize,\n", "especially in distributed settings.\n", "\n", "The Transformer architecture,\n", "on the other hand,\n", "can be readily parallelized within a single example sequence,\n", "in addition to parallelization across batches.\n", "This can lead to massive performance gains for a fixed scale,\n", "which means larger, higher capacity models\n", "can be trained on larger datasets."
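, "\n",
"\n",
"The contrast between sequential recursion and within-sequence parallelism\n",
"can be sketched with a toy example.\n",
"This is only an illustration under simplifying assumptions:\n",
"`rnn_like` and `attention_like` are hypothetical scalar stand-ins,\n",
"not the modules used in this lab and not real multi-head attention.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def rnn_like(xs, decay=0.5):\n",
"    # hypothetical scalar recurrence: step t needs the hidden state from step t - 1,\n",
"    # so the loop over the sequence cannot be run in parallel\n",
"    h, hs = 0.0, []\n",
"    for x in xs:\n",
"        h = math.tanh(decay * h + x)\n",
"        hs.append(h)\n",
"    return hs\n",
"\n",
"\n",
"def attention_like(xs):\n",
"    # hypothetical scalar causal attention: the output at position t reads only\n",
"    # the raw inputs xs[0..t], never another position's output\n",
"    def one_position(t):\n",
"        scores = [xs[t] * xs[s] for s in range(t + 1)]\n",
"        top = max(scores)  # subtract the max for a numerically stable softmax\n",
"        weights = [math.exp(score - top) for score in scores]\n",
"        return sum(w * xs[s] for s, w in enumerate(weights)) / sum(weights)\n",
"\n",
"    # the calls are independent of each other, so any order (or all at once) works\n",
"    return [one_position(t) for t in range(len(xs))]\n",
"\n",
"\n",
"xs = [0.1, -0.4, 0.7]\n",
"print(rnn_like(xs))\n",
"print(attention_like(xs))\n",
"```\n",
"\n",
"In `rnn_like`, each iteration waits on the previous one.\n",
"In `attention_like`, every position depends only on the inputs,\n",
"which is what lets real implementations compute all positions in one batched matrix multiplication."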
] }, { "cell_type": "markdown", "metadata": { "id": "_Mzk2haFC_G1" }, "source": [ "How does the architecture achieve this parallelizability?\n", "\n", "Let's start with the architecture diagram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u59eu4snLQfp" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n", "\n", "display.Image(url=base_url + \"/aiayn-figure-1.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ez-XEQ7M0UlR" }, "source": [ "> To head off a bit of confusion\n", " in case you've worked with Transformer architectures before:\n", " the original \"Transformer\" is an encoder/decoder architecture.\n", " Many LLMs, like GPT models, are decoder only,\n", " because this has turned out to scale well,\n", " and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n", " it's all text anyways.\n", " We, however, will be using them across modalities,\n", " so we need an explicit encoder,\n", " as above. " ] }, { "cell_type": "markdown", "metadata": { "id": "ok4ksBi4vp89" }, "source": [ "First focusing on the encoder (left):\n", "the encoding at a given position is a function of all previous inputs.\n", "But it is not a function of the previous _encodings_:\n", "we produce the encodings \"all at once\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "RPN7C-_OqzHP" }, "source": [ "The decoder (right) does use previous \"outputs\" as its inputs,\n", "but those outputs are not the vectors of layer activations\n", "(aka embeddings)\n", "that are produced by the network.\n", "They are instead the processed outputs,\n", "after a `softmax` and an `argmax`.\n", "\n", "We could obtain these outputs by processing the embeddings,\n", "much like in a recurrent architecture.\n", "In fact, that is one way that Transformers are run.\n", "It's what happens in the `.forward` method\n", "of the model we'll be training for character recognition:\n", "`ResnetTransformer`." ] }, { "cell_type": "markdown", "metadata": { "id": "L5_2WMmtDnJn" }, "source": [ "Let's look at that forward method\n", "and connect it to the diagram." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FR5pk4kEyCGg" }, "outputs": [], "source": [ "from text_recognizer.models import ResnetTransformer\n", "\n", "\n", "ResnetTransformer.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "-J5UFDoPzPbq" }, "source": [ "`.encode` happens first -- that's the left side of the diagram.\n", "\n", "The encoder can in principle be anything\n", "that produces a sequence of fixed-length vectors,\n", "but here it's\n", "[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n", "\n", "Then we start iterating over the sequence\n", "in the `for` loop.\n", "\n", "Focus on the first few lines of code.\n", "We apply `.decode` (right side of the diagram)\n", "to the outputs so far.\n", "\n", "Once we have a new `output`, we apply `.argmax`\n", "to turn the logits into a concrete prediction of\n", "a particular token.\n", "\n", "This is added as the last output token\n", "and then the loop happens again."
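, "\n",
"\n",
"The shape of that loop can be sketched in plain Python.\n",
"This is a minimal sketch, not the repository's `ResnetTransformer.forward`:\n",
"the token indices `START`, `END`, and `PAD`, the `next_logits` callable,\n",
"and the `toy_next_logits` stand-in decoder are all hypothetical names chosen for illustration.\n",
"\n",
"```python\n",
"from typing import Callable, List, Sequence\n",
"\n",
"START, END, PAD = 0, 1, 2  # hypothetical special-token indices\n",
"\n",
"\n",
"def argmax(values: Sequence[float]) -> int:\n",
"    # index of the largest value (the first one wins on ties)\n",
"    return max(range(len(values)), key=lambda i: values[i])\n",
"\n",
"\n",
"def greedy_decode(next_logits: Callable[[List[int]], Sequence[float]], max_len: int) -> List[int]:\n",
"    # generate tokens one at a time, feeding each prediction back in\n",
"    tokens = [START]\n",
"    for _ in range(max_len - 1):\n",
"        logits = next_logits(tokens)  # decode the outputs so far\n",
"        token = argmax(logits)  # turn logits into one concrete token\n",
"        tokens.append(token)  # it becomes an input on the next pass\n",
"        if token == END:\n",
"            break\n",
"    return tokens + [PAD] * (max_len - len(tokens))  # pad to a fixed length\n",
"\n",
"\n",
"def toy_next_logits(tokens: List[int]) -> List[float]:\n",
"    # stand-in for a decoder: emits tokens 3, 4, and then END\n",
"    script = [3, 4, END]\n",
"    target = script[min(len(tokens) - 1, len(script) - 1)]\n",
"    return [1.0 if i == target else 0.0 for i in range(5)]\n",
"\n",
"\n",
"print(greedy_decode(toy_next_logits, max_len=6))  # prints [0, 3, 4, 1, 2, 2]\n",
"```\n",
"\n",
"Each pass through the loop depends on the token chosen in the previous pass,\n",
"so generation proceeds one step at a time."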
] }, { "cell_type": "markdown", "metadata": { "id": "LTcy8-rV1dHr" }, "source": [ "Run this way, our model looks very much like a recurrent architecture:\n", "we call the model on its own outputs\n", "to generate the next value.\n", "These types of models are also referred to as\n", "[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n", "because we predict (as we do in _regression_)\n", "the next value based on our own (_auto_) output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But Transformers are designed to be _trained_ more scalably than RNNs,\n", "not necessarily to _run inference_ more scalably,\n", "and it's actually not the case that our model's `.forward` is called during training." ] }, { "cell_type": "markdown", "metadata": { "id": "eCxMSAWmEKBt" }, "source": [ "Let's look at what happens during training\n", "by checking the `training_step`\n", "of the `LightningModule`\n", "we use to train our Transformer models,\n", "the `TransformerLitModel`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0o7q8N7P2w4H" }, "outputs": [], "source": [ "from text_recognizer.lit_models import TransformerLitModel\n", "\n", "TransformerLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "1VgNNOjvzC4y" }, "source": [ "Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`." ] }, { "cell_type": "markdown", "metadata": { "id": "tz-6NGPR4dUr" }, "source": [ "Let's look at `.teacher_forward`,\n", "and in particular its type signature:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ILc2oWET4i2Z" }, "outputs": [], "source": [ "TransformerLitModel.teacher_forward??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`." 
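, "\n",
"\n",
"One common way to use the targets this way is to shift them by one position:\n",
"the decoder sees the ground truth up to position `t`\n",
"and is scored on the ground-truth token at `t + 1`.\n",
"The sketch below assumes that pattern rather than reproducing the code of `TransformerLitModel`;\n",
"`teacher_forcing_pairs` and the token indices are hypothetical.\n",
"\n",
"```python\n",
"from typing import List, Tuple\n",
"\n",
"START, END = 0, 1  # hypothetical special-token indices\n",
"\n",
"\n",
"def teacher_forcing_pairs(y: List[int]) -> List[Tuple[List[int], int]]:\n",
"    # decoder inputs drop the last token, prediction targets drop the first\n",
"    decoder_inputs, targets = y[:-1], y[1:]\n",
"    # every (context, target) pair comes from the ground truth alone,\n",
"    # so no pair has to wait for the model's own earlier predictions\n",
"    return [(decoder_inputs[: t + 1], targets[t]) for t in range(len(targets))]\n",
"\n",
"\n",
"y = [START, 7, 8, 9, END]\n",
"for context, target in teacher_forcing_pairs(y):\n",
"    print(context, '->', target)\n",
"```\n",
"\n",
"Because all of the contexts are known up front,\n",
"the predictions for every position can be computed in a single pass during training."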
] }, { "cell_type": "markdown", "metadata": { "id": "lf32lpgrDb__" }, "source": [ "This is known as \"teacher forcing\".\n", "The \"teacher\" signal is \"forcing\"\n", "the model to behave as though\n", "it got the answer right.\n", "\n", "[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n", "It's more effective here\n", "because the right teaching signal\n", "for our network is the target data,\n", "which we have access to during training,\n", "whereas in an RNN the best teaching signal\n", "would be the target embedding vector,\n", "which we do not know.\n", "\n", "During inference, when we don't have access to the ground truth,\n", "we revert to the autoregressive `.forward` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This \"trick\" allows Transformer architectures to readily scale\n", "up models to the parameter counts\n", "[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)." ] }, { "cell_type": "markdown", "metadata": { "id": "BAjqpJm9uUuU" }, "source": [ "## Is there more to Transformers than just a training trick?" ] }, { "cell_type": "markdown", "metadata": { "id": "kWCYXeHv7Qc9" }, "source": [ "[Very](https://arxiv.org/abs/2005.14165),\n", "[very](https://arxiv.org/abs/1909.08053),\n", "[very](https://arxiv.org/abs/2205.01068)\n", "large Transformer models have powered the most recent wave of exciting results in ML, like\n", "[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n", "\n", "They are also the first machine learning models to have come anywhere close to\n", "deserving the term _artificial intelligence_ --\n", "a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is surprising because the models and their training procedure are\n", "(relatively speaking)\n", "pretty _simple_,\n", "even if it doesn't feel that way on first pass." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic Transformer architecture is just a bunch of\n", "dense matrix multiplications and non-linearities --\n", "it's perhaps simpler than a convolutional architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And advances since the introduction of Transformers in 2017\n", "have not in the main been made by\n", "creating more sophisticated model architectures\n", "but by increasing the scale of the base architecture,\n", "or if anything making it simpler, as in\n", "[GPT-type models](https://arxiv.org/abs/2005.14165),\n", "which drop the encoder." ] }, { "cell_type": "markdown", "metadata": { "id": "V1HQS9ey8GMc" }, "source": [ "These models are also trained on very simple tasks:\n", "most LLMs are just trying to predict the next element in the sequence,\n", "given the previous elements --\n", "a task simple enough that Claude Shannon,\n", "father of information theory, was\n", "[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n", "\n", "These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n", "e.g. by scraping the web." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They are also trained in a simple fashion:\n", "first-order stochastic optimizers, like SGD or an\n", "[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n", "intended for the most basic of optimization problems,\n", "that scale more readily than the second-order optimizers\n", "that dominate other areas of optimization." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kz9HPDoy7OAl" }, "source": [ "This is\n", "[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n", "of work in ML:\n", "simple, even seemingly wasteful,\n", "architectures that scale well and are robust\n", "to implementation details\n", "eventually outstrip more clever but\n", "also more finicky approaches that are harder to scale.\n", "This lesson has led some to declare that\n", "[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n", "in machine learning, and perhaps even in artificial intelligence." ] }, { "cell_type": "markdown", "metadata": { "id": "SdN9o2Y771YZ" }, "source": [ "> That is not to say that because the algorithms are relatively simple,\n", " training a model at this scale is _easy_ --\n", " [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n", " [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n", " [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n", " But choosing the simplest algorithm at every step makes solving the scaling problem feasible." ] }, { "cell_type": "markdown", "metadata": { "id": "baVGf6gKFOvs" }, "source": [ "The importance of scale is the key lesson from the Transformer architecture,\n", "far more than any theoretical considerations\n", "or any of the implementation details.\n", "\n", "That said, these large Transformer models are capable of\n", "impressive behaviors and understanding how they achieve them\n", "is of intellectual interest.\n", "Furthermore, like any architecture,\n", "there are common failure modes,\n", "of the model and of the modelers who use them,\n", "that need to be taken into account." 
] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "Below, we'll cover two key intuitions about Transformers:\n", "Transformers are _residual_, like ResNets,\n", "and they compose _low rank_ sequence transformations.\n", "Together, this means they act somewhat like a computer,\n", "reading from and writing to a \"tape\" or memory\n", "with a sequence of simple instructions." ] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "We'll also cover a surprising implementation detail:\n", "despite being commonly used for sequence modeling,\n", "by default the architecture is _position insensitive_." ] }, { "cell_type": "markdown", "metadata": { "id": "uni0VTCr9lev" }, "source": [ "### Intuition #1: Transformers are highly residual." ] }, { "cell_type": "markdown", "metadata": { "id": "0MoBt-JLJz-d" }, "source": [ "> These intuitions summarize the discussion in\n", "[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n", "from\n", "[Anthropic](https://www.anthropic.com/),\n", "an AI safety and research company.\n", "The figures below are from that blog post.\n", "It is the spiritual successor to the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "covered in\n", "[Lab 02b](https://fsdl.me/lab02b-colab).\n", "If you want to truly understand Transformers,\n", "we highly recommend you check it out,\n", "including the\n", "[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
] }, { "cell_type": "markdown", "metadata": { "id": "UUbNVvM5Ferm" }, "source": [ "It's easy to see that ResNets are residual --\n", "it's in the name, after all.\n", "\n", "But Transformers are,\n", "in some sense,\n", "even more closely tied to residual computation\n", "than are ResNets:\n", "ResNets and related architectures include downsampling,\n", "so there is not a direct path from inputs to outputs.\n", "\n", "In Transformers, the exact same shape is maintained\n", "from the moment tokens are embedded,\n", "through dozens or hundreds of intermediate layers,\n", "and until they are \"unembedded\" into class logits.\n", "The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n", "\n", "The residual stream is easy to see with a change of perspective.\n", "Instead of the usual architecture diagram above,\n", "which emphasizes the layers acting on the tensors,\n", "consider this alternative view,\n", "which emphasizes the tensors as they pass through the layers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HRMlVguKKW6y" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-residual-view.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9K3N7ilVkB3" }, "source": [ "For definitions of variables and terms, see the\n", "[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)." ] }, { "cell_type": "markdown", "metadata": { "id": "arvciE-kKd_L" }, "source": [ "Note that this is a _decoder-only_ Transformer architecture --\n", "so it should be compared with the right-hand side of the original architecture diagram above."
] }, { "cell_type": "markdown", "metadata": { "id": "wvrRMd_RKp_G" }, "source": [ "Notice that outputs of the attention blocks\n", "and of the MLP layers are\n", "added to their inputs, as in a ResNet.\n", "These operations are represented as \"Add & Norm\" layers in the classical diagram;\n", "normalization is ignored here for simplicity." ] }, { "cell_type": "markdown", "metadata": { "id": "o8n_iT-FFAbK" }, "source": [ "This total commitment to residual operations\n", "means the size of the embeddings\n", "(referred to as the \"model dimension\" or the \"embedding dimension\",\n", "here and below `d_model`)\n", "stays the same throughout the entire network.\n", "\n", "That means, for example,\n", "that the output of each layer can be used as input to the \"unembedding\" layer\n", "that produces logits.\n", "We can read out the computations of intermediate layers\n", "just by passing them through the unembedding layer\n", "and examining the logit tensor.\n", "See\n", "[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n", "for detailed experiments and interactive notebooks.\n", "\n", "In short, we observe a sort of \"progressive refinement\"\n", "of the next-token prediction\n", "as the embeddings proceed, depthwise, through the network." ] }, { "cell_type": "markdown", "metadata": { "id": "Ovh_3YgY9z2h" }, "source": [ "### Intuition #2: Transformer heads learn low rank transformations."
] }, { "cell_type": "markdown", "metadata": { "id": "XpNmozlnOdPC" }, "source": [ "In the original paper and in\n", "most presentations of Transformers,\n", "the attention layer is written like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PA7me8gNP5LE" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In pseudo-typed PyTorch (based loosely on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n", "that looks like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Oeict_6wGJgD" }, "source": [ "```python\n", "def classic_attention(\n", " Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " K: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " return torch.softmax(Q @ K.T, dim=-1) @ V # softmax over the positions being attended to\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "8pewU90DSuOR" }, "source": [ "This is almost exactly\n", "how it is written\n", "in PyTorch,\n", "apart from implementation details\n", "(look for `bmm` for the matrix multiplications and a `softmax` call):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrgTpKFvOhwc" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "F._scaled_dot_product_attention??"
] }, { "cell_type": "markdown", "metadata": { "id": "ebDXZ0tlSe7g" }, "source": [ "But the best way to write an operation so that a computer can execute it quickly\n", "is not necessarily the best way to write it so that a human can understand it --\n", "otherwise we'd all be coding in assembly.\n", "\n", "And this is a strange way to write it --\n", "you'll notice that what we normally think of\n", "as the \"inputs\" to the layer are not shown.\n", "\n", "We can instead write out the attention layer\n", "as a function of the inputs $x$.\n", "We write it for a single \"attention head\".\n", "Each attention layer includes a number of heads\n", "that read and write from the residual stream\n", "simultaneously and independently.\n", "We also add the output layer weights $W_O$\n", "and we get:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LuFNR67tQpsf" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(\\underbrace{x^TW_Q^T}_Q \\underbrace{W_Kx}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SVnBjjfOLwxP" }, "source": [ "or, in pseudo-typed PyTorch:" ] }, { "cell_type": "markdown", "metadata": { "id": "LmpOm-HfGaNz" }, "source": [ "```python\n", "def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n", " key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n", " key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n", " # maps queries of residual stream to keys from residual stream, independent of position\n", "\n", " value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n", " output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n", " value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n", " # transformation applied to each token, regardless of position\n", "\n", " attention_logits = 
x @ key_query_circuit @ x.T\n", " attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n", " # maps positions to positions, often very sparse\n", "\n", " value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n", "\n", " return attention_map @ value_output # transformed tokens filtered by attention map\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "dC0eqxZ6UAGT" }, "source": [ "Consider the `key_query_circuit`\n", "and `value_output_circuit`\n", "matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n", "\n", "The key/query dimension, `d_head`\n", "is small relative to the model's dimension, `d_model`,\n", "so $W_{QK}$ and $W_{OV}$ are very low rank,\n", "[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n", "that they factorize into two matrices,\n", "one with a smaller number of rows\n", "and another with a smaller number of columns.\n", "That number is called the _rank_.\n", "\n", "When computing, these matrices are better represented via their components,\n", "rather than computed directly,\n", "which leads to the normal implementation of attention.\n", "\n", "In a large language model,\n", "the ratio of residual stream dimension, `d_model`, to\n", "the dimension of a single head, `d_head`, is huge, often 100:1.\n", "That means each query, key, and value computed at a position\n", "is a fairly simple, low-dimensional feature of the residual stream at that position.\n", "\n", "For visual intuition,\n", "we compare a matrix whose rank is one hundredth of full rank\n", "with a full rank matrix of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LUbojJMiW2C" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "\n", "low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n", "full_rank = torch.randn(100, 100)\n", "plt.figure(); 
plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n", "plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");" ] }, { "cell_type": "markdown", "metadata": { "id": "lqBst92-OVka" }, "source": [ "The pattern in the first matrix is very simple,\n", "relative to the pattern in the second matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "SkCGrs9EiVh4" }, "source": [ "Another feature of low rank transformations is\n", "that they have a large nullspace or kernel --\n", "these are directions we can move the input without changing the output.\n", "\n", "That means that many changes to the residual stream won't affect the behavior of this head at all." ] }, { "cell_type": "markdown", "metadata": { "id": "UVz2dQgzhD4p" }, "source": [ "### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)." ] }, { "cell_type": "markdown", "metadata": { "id": "hVlzwR03m8mC" }, "source": [ "The combination of residuality\n", "(changes are added to the current input)\n", "and low rank\n", "(only a small subspace is changed by each head)\n", "drastically changes the intuition about Transformers." ] }, { "cell_type": "markdown", "metadata": { "id": "qqjZI2jKe6HH" }, "source": [ "Rather than being an \"embedding of a token in its context\",\n", "the residual stream becomes something more like a memory or a scratchpad:\n", "one layer reads a small bit of information from the stream\n", "and writes a small bit of information back to it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YIBkxlqepjc" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-layer-residual.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtsKhkLfk00l" }, "source": [ "The residual stream works like a memory because it is roomy enough\n", "that these actions need not interfere:\n", "the subspaces targeted by reads and writes are small relative to the ambient space,\n", "so they can largely avoid overlapping.\n", "\n", "Additionally, the dimension of each head is still in the 100s in large models,\n", "and\n", "[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n", "in them, so the number of effective degrees of freedom is\n", "actually larger than the dimension.\n", "This phenomenon allows high-dimensional tensors to serve as\n", "[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n", "There are\n", "[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n", "\n", "Together, this means an early layer can write information to the stream\n", "that can be used by later layers -- by many of them at once, possibly much later.\n", "Later layers can learn to edit this information,\n", "e.g. deleting it,\n", "if doing so reduces the loss,\n", "but by default the information is preserved." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EragIygzJg86" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-stream-read-write.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "oKIaUZjwkpW7" }, "source": [ "Lastly, the softmax in the attention has a sparsifying effect,\n", "and so many attention heads are reading from\n", "just one token and writing to just one other token."
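] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of that sparsifying effect,\n", "the toy sketch below applies a softmax to a handful of made-up attention logits\n", "(the values and the name `toy_logits` are assumptions for illustration, not taken from a trained model).\n", "As the logits grow in magnitude, nearly all of the attention weight lands on a single position." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "toy_logits = torch.tensor([2.0, 1.0, 0.5, 0.0])  # hypothetical attention logits for a single query\n", "\n", "for scale in [1.0, 4.0, 16.0]:\n", "    weights = torch.softmax(scale * toy_logits, dim=-1)  # attention weights over four positions\n", "    print(f\"scale={scale}: largest attention weight = {weights.max().item():.3f}\")"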
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dN6VcJqIMKnB" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-token-to-token.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Repeatedly reading information from an external memory\n", "and using it to decide which operation to perform\n", "and where to write the results\n", "is at the core of the\n", "[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n", "For a concrete example, the\n", "[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n", "includes a dissection of a form of \"pointer arithmetic\"\n", "that appears in some models." ] }, { "cell_type": "markdown", "metadata": { "id": "0kLFh7Mvnolr" }, "source": [ "This point of view seems\n", "very promising for explaining numerous\n", "otherwise perhaps counterintuitive features of Transformer models.\n", "\n", "- This framework predicts that Transformers will readily copy-and-paste information,\n", "which might explain phenomena like\n", "[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n", "\n", "- It also readily explains\n", "[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n", "an important component of why Transformers perform well on medium-length texts\n", "and in few-shot learning.\n", "\n", "- Transformers also perform better on reasoning tasks when the text\n", "[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n", "is added to their input prompt.\n", "This is partly due to the fact that that prompt is associated,\n", "in the dataset, with clearer reasoning,\n", "and since the models are trained to predict which tokens tend to appear\n", "after an input, they tend to produce better reasoning with that prompt --\n", "an explanation purely in terms of sequence modeling.\n", 
"But it also gives the Transformer license to generate a large number of tokens\n", "that act to store intermediate information,\n", "making for a richer residual stream\n", "for reading and writing." ] }, { "cell_type": "markdown", "metadata": { "id": "RyLRzgG-93yB" }, "source": [ "### Implementation detail: Transformers are position-insensitive by default." ] }, { "cell_type": "markdown", "metadata": { "id": "oR6PnrlA_hJ2" }, "source": [ "In the attention calculation\n", "each token can query each other token,\n", "with no regard for order.\n", "Furthermore, the construction of queries, keys, and values\n", "is based on the content of the embedding vector,\n", "which does not automatically include its position.\n", "\"dog bites man\" and \"man bites dog\" are identical, as in\n", "[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n", "\n", "For most sequences,\n", "this is unacceptable:\n", "absolute and relative position matter\n", "and we cannot use the future to predict the past.\n", "\n", "We need to add two pieces to get a Transformer architecture that's usable for next-token prediction." ] }, { "cell_type": "markdown", "metadata": { "id": "EWHxGJz2-6ZK" }, "source": [ "First, the simpler piece:\n", "\"causal\" attention,\n", "so-named because it ensures that values earlier in the sequence\n", "are not influenced by later values, which would\n", "[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)." 
] }, { "cell_type": "markdown", "metadata": { "id": "0c42xi6URYB4" }, "source": [ "The most common solution is straightforward:\n", "we calculate attention between all tokens,\n", "then throw out non-causal values by \"masking\" them\n", "(this is before applying the softmax,\n", "so masking means adding $-\\infty$).\n", "\n", "This feels wasteful --\n", "why are we calculating values we don't need?\n", "Trying to be smarter would be harder,\n", "and might rely on operations that aren't as optimized as\n", "matrix multiplication and addition.\n", "Furthermore, it's \"only\" twice as many operations,\n", "so it doesn't even show up in $O$-notation.\n", "\n", "A sample attention mask generated by our code base is shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NXaWe6pT-9jV" }, "outputs": [], "source": [ "from text_recognizer.models import transformer_util\n", "\n", "\n", "attention_mask = transformer_util.generate_square_subsequent_mask(100)\n", "\n", "ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n", "plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n", "print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This solves our causality problem,\n", "but we still don't have positional information." 
] }, { "cell_type": "markdown", "metadata": { "id": "ZamUE4WIoGS2" }, "source": [ "The standard technique\n", "is to add alternating sines and cosines\n", "of increasing frequency to the embeddings\n", "(there are\n", "[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n", "most notably\n", "[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n", "Each position in the sequence is then uniquely identifiable\n", "from the pattern of these values.\n", "\n", "> Furthermore, for the same reason that\n", " [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n", " translations, e.g. relative positions, are fairly easy to express as linear transformations\n", " of sines and cosines." ] }, { "cell_type": "markdown", "metadata": { "id": "IDG2uOsaELU0" }, "source": [ "We superimpose this positional information on our embeddings.\n", "Note that because the model is residual,\n", "this position information will be by default preserved\n", "as it passes through the network,\n", "so it doesn't need to be repeatedly added."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what this positional encoding looks like in our codebase:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Zk62Q-a-1Ax" }, "outputs": [], "source": [ "PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n", "\n", "pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n", "\n", "ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n", "print(pe[:4, :8])" ] }, { "cell_type": "markdown", "metadata": { "id": "ep2ClIWvqDms" }, "source": [ "When we add the positional information to our embeddings,\n", "both the embedding information and the positional information\n", "are approximately preserved,\n", "as can be visually assessed below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PJuFjoCzC0Y4" }, "outputs": [], "source": [ "fake_embeddings = torch.randn_like(pe) * 0.5\n", "\n", "ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n", "\n", "fake_embeddings_with_pe = fake_embeddings + pe\n", "\n", "plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIzBxDkEmH8" }, "source": [ "A [similar technique](https://arxiv.org/abs/2103.06450)\n", "is used to also incorporate positional information into the image embeddings,\n", "which are flattened before being fed to the decoder."
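] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the sine-and-cosine recipe concrete, here is a minimal sketch of a sinusoidal positional encoding,\n", "written from the description above rather than copied from `transformer_util.PositionalEncoding`.\n", "The helper name `sinusoidal_encoding`, its arguments, and the base of `10000`\n", "(the value used in the original Transformer paper) are assumptions for illustration,\n", "so details may differ from what our codebase does." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "def sinusoidal_encoding(max_len, d_model):\n", "    # hypothetical helper for illustration; assumes d_model is even\n", "    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # shape (max_len, 1)\n", "    # one frequency per sine/cosine pair, spaced geometrically\n", "    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))\n", "    encoding = torch.zeros(max_len, d_model)\n", "    encoding[:, 0::2] = torch.sin(position * freqs)  # even embedding dimensions get sines\n", "    encoding[:, 1::2] = torch.cos(position * freqs)  # odd embedding dimensions get cosines\n", "    return encoding\n", "\n", "\n", "sketch_pe = sinusoidal_encoding(max_len=200, d_model=50)\n", "print(sketch_pe.shape)  # one row per position, one column per embedding dimension"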
] }, { "cell_type": "markdown", "metadata": { "id": "HC1N85wl8dvn" }, "source": [ "### Learn more about Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "lJwYxkjTk15t" }, "source": [ "We're only able to give a flavor and an intuition for Transformers here.\n", "\n", "To improve your grasp on the nuts and bolts, check out the\n", "[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n", "which is surprisingly approachable,\n", "as far as ML research papers go.\n", "The\n", "[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n", "adds code and commentary to the original paper,\n", "which makes it even more digestible.\n", "For something even friendlier, check out the\n", "[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n", "by Jay Alammar, which has an accompanying\n", "[video](https://youtu.be/-QH8fRhqFHM).\n", "\n", "Anthropic's work on\n", "[Transformer Circuits](https://transformer-circuits.pub/),\n", "summarized above, has some of the best material\n", "for building theoretical understanding\n", "and is still being updated with extensions and applications of the framework.\n", "The\n", "[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n", "are a great aid for checking and building your understanding.\n", "\n", "But they are fairly math-heavy.\n", "If you have more of a software engineering background, see\n", "Transformer Circuits co-author Nelson Elhage's blog post\n", "[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n", "\n", "For a gentler introduction to the intuition for Transformers,\n", "check out Brandon Rohrer's\n", "[Transformers From Scratch](https://e2eml.school/transformers.html)\n", "tutorial." 
] }, { "cell_type": "markdown", "metadata": { "id": "qg7zntJES-aT" }, "source": [ "An aside:\n", "the matrix multiplications inside attention dominate\n", "the big-$O$ runtime of Transformers.\n", "So trying to make the attention mechanism more efficient, e.g. linear time,\n", "has generated a lot of research\n", "(review paper\n", "[here](https://arxiv.org/abs/2009.06732)).\n", "Despite drawing a lot of attention, so to speak,\n", "at the time of writing in mid-2022, these methods\n", "[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n", "so it isn't likely to be worth the effort to spend time learning about them\n", "unless you are a Transformer specialist." ] }, { "cell_type": "markdown", "metadata": { "id": "vCjXysEJ8g9_" }, "source": [ "# Using Transformers to read paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "KsfKWnOvqjva" }, "source": [ "Our simple convolutional model for text recognition from\n", "[Lab 02b](https://fsdl.me/lab02b-colab)\n", "could only handle cleanly-separated characters.\n", "\n", "It worked by sliding a LeNet-style CNN\n", "over the image,\n", "predicting a character for each step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njLdzBqy-I90" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist_lines = text_recognizer.data.EMNISTLines()\n", "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "\n", "# for sliding, see the for loop over range(S)\n", "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "K0N6yDBQq8ns" }, "source": [ "But unfortunately for us, handwritten text\n", "doesn't come in neatly-separated characters\n", "of equal size, so we trained our model on synthetic data\n", "designed to work with that model." 
] }, { "cell_type": "markdown", "metadata": { "id": "hiqUVbj0sxLr" }, "source": [ "Now that we have a better model,\n", "we can work with better data:\n", "paragraphs from the\n", "[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)." ] }, { "cell_type": "markdown", "metadata": { "id": "oizsOAcKs-dD" }, "source": [ "The cell below uses our `LightningDataModule`\n", "to download and preprocess this data,\n", "writing results to disk.\n", "We can then spin up `DataLoader`s to give us batches.\n", "\n", "It can take several minutes to run the first time\n", "on commodity machines,\n", "with most time spent extracting the data.\n", "On subsequent runs,\n", "the time-consuming operations will not be repeated." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uL9LHbjdsUbm" }, "outputs": [], "source": [ "iam_paragraphs = text_recognizer.data.IAMParagraphs()\n", "\n", "iam_paragraphs.prepare_data()\n", "iam_paragraphs.setup()\n", "xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n", "\n", "iam_paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "nBkFN9bbTm_S" }, "source": [ "Now that we've got a batch,\n", "let's take a look at some samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hqaps8yxtBhU" }, "outputs": [], "source": [ "import random\n", "\n", "import numpy as np\n", "import wandb\n", "\n", "\n", "def show(y):\n", " y = y.detach().cpu() # bring back from accelerator if it's being used\n", " return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"<P>\", \"\")\n", "\n", "idx = random.randint(0, len(xs) - 1) # randint includes both endpoints\n", "\n", "print(show(ys[idx]))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "4dT3UCNzTsoc" }, "source": [ "The `ResnetTransformer` model can run on this data\n", "if passed the `.config`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WXL-vIGRr86D" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())" ] }, { "cell_type": "markdown", "metadata": { "id": "MMxa-oWyT01E" }, "source": [ "Our models are now big enough\n", "that we want to make use of GPU acceleration\n", "as much as we can,\n", "even when working on single inputs,\n", "so let's cast to the GPU if we have one." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YyUM8LgvW0w" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "rnt.to(device); xs = xs.to(device); ys = ys.to(device);" ] }, { "cell_type": "markdown", "metadata": { "id": "Y-E3UdD4zUJi" }, "source": [ "First, let's just pass it through the ResNet encoder."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-LUUtlvaxrvg" }, "outputs": [], "source": [ "resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n", " # resnet is designed for RGB images, so we replicate the input across channels 3 times" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eimgJ5dnywjg" }, "outputs": [], "source": [ "resnet_idx = random.randint(0, len(resnet_embedding) - 1)  # re-execute to view a different channel\n", "plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n", "plt.axis(\"off\"); plt.colorbar(fraction=0.05);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These embeddings, though generated by random, untrained weights,\n", "are not entirely useless.\n", "\n", "Before neural networks could be effectively\n", "trained end to end,\n", "they were often used with frozen random weights\n", "everywhere except the final layer\n", "(see e.g.\n", "[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n", "[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n", "these methods were still competitive, and\n", "[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n", "provide a\n", "[theoretical basis](https://arxiv.org/abs/2011.14522)\n", "for understanding their performance." ] }, { "cell_type": "markdown", "metadata": { "id": "ye6pW0ETzw2A" }, "source": [ "The final result, though, is repetitive gibberish --\n", "at the bare minimum, we need to train the unembedding/readout layer\n", "in order to get reasonable text." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our architecture includes randomization with dropout,\n", "so repeated runs of the cell below will generate different outcomes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu3Pa7gLsFMo" }, "outputs": [], "source": [ "preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gvCXUbskv6XM" }, "outputs": [], "source": [ "print(show(preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without teacher forcing, runtime is also variable from iteration to iteration --\n", "the model stops when it generates an \"end sequence\" or padding token,\n", "which is not deterministic thanks to the dropout layers.\n", "For similar reasons, runtime is variable across inputs.\n", "\n", "The variable runtime of autoregressive generation\n", "is also not great for scaling.\n", "In a distributed setting, as required for large scale,\n", "forward passes need to be synced across devices,\n", "and if one device is generating a batch of much longer sequences,\n", "it will cause all the others to idle while they wait on it to finish." ] }, { "cell_type": "markdown", "metadata": { "id": "t76MSVRXV0V7" }, "source": [ "Let's turn our model into a `TransformerLitModel`\n", "so we can run with teacher forcing.\n", "\n", "> You may be wondering:\n", " why isn't teacher forcing part of the PyTorch module?\n", " In general, the `LightningModule`\n", " should encapsulate things that are needed in training, validation, and testing\n", " but not during inference.\n", " The teacher forcing trick fits this paradigm,\n", " even though it's so critical to what makes Transformers powerful. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8qrHRKHowdDi" }, "outputs": [], "source": [ "import text_recognizer.lit_models\n", "\n", "lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlNaFqR50Oid" }, "source": [ "Now we can use `.teacher_forward` if we also provide the target `ys`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lpZdqXS5wn0F" }, "outputs": [], "source": [ "forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Zx9SmsN0QLT" }, "source": [ "This may not run faster than the `rnt.forward`,\n", "since generations are always the maximum possible length,\n", "but runtimes and output lengths are deterministic and constant." ] }, { "cell_type": "markdown", "metadata": { "id": "tu-XNYpi0Qvi" }, "source": [ "Forcing doesn't necessarily make our predictions better.\n", "They remain highly repetitive gibberish." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JcEgify9w0sv" }, "outputs": [], "source": [ "forcing_preds = torch.argmax(forcing_outs, dim=0)\n", "\n", "print(show(forcing_preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xn6GGNzc9a3o" }, "source": [ "## Training the `ResNetTransformer`" ] }, { "cell_type": "markdown", "metadata": { "id": "uvZYsuSyWUXe" }, "source": [ "We're finally ready to train this model on full paragraphs of handwritten text!" 
] }, { "cell_type": "markdown", "metadata": { "id": "3cJwC7b720Sd" }, "source": [ "This is a more serious model --\n", "it's the one we use in the\n", "[deployed TextRecognizer application](http://fsdl.me/app).\n", "It's much larger than the models we've seen thus far,\n", "so it can easily outstrip available compute resources,\n", "in particular GPU memory.\n", "\n", "To help, we use\n", "[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n", "which halves the size of most of our floats,\n", "reducing memory consumption and potentially speeding up computation.\n", "\n", "If your GPU has less than 8GB of available RAM,\n", "you'll see a \"CUDA out of memory\" `RuntimeError`,\n", "which is something of a\n", "[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n", "In this case, you can resolve it by reducing the `--batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w1mXlhfy04Nm" }, "outputs": [], "source": [ "import torch\n", "\n", "gpus = int(torch.cuda.is_available())\n", "\n", "if gpus:\n", "    !nvidia-smi\n", "else:\n", "    print(\"watch out! working with this model on a typical CPU is not feasible\")" ] }, { "cell_type": "markdown", "metadata": { "id": "os1vW1rPZ1dy" }, "source": [ "Even with an okay GPU, like a\n", "[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n", "a single epoch of training can take over 10 minutes to run.\n", "We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n", "but you can remove those flags to see what full training looks like."
] }, { "cell_type": "markdown", "metadata": { "id": "vnF6dWFn4JlZ" }, "source": [ "It can take a long time (overnight)\n", "to train this model to decent performance on a single GPU,\n", "so we'll focus on other pieces for the exercises.\n", "\n", "> At the time of writing in mid-2022, the cheapest readily available option\n", "for training this model to decent performance on this dataset with this codebase\n", "comes out around $10, using\n", "[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n", "See, for example,\n", "[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n", "and associated experiment.\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HufjdUZN0t4l", "scrolled": false }, "outputs": [], "source": [ "%%time\n", "# above %%magic times the cell, useful as a poor man's profiler\n", "\n", "%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n", " --gpus={gpus} --batch_size 16 --precision 16 \\\n", " --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2" ] }, { "cell_type": "markdown", "metadata": { "id": "L6fQ93ju3Iku" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "udb1Ekjx3L63" }, "source": [ "### 🌟 Try out gradient accumulation and other \"training tricks\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "kpqViB4p3Wfb" }, "source": [ "Larger batches are helpful not only for increasing parallelization\n", "and amortizing fixed costs\n", "but also for getting more reliable gradients.\n", "Larger batches give gradients with less noise\n", "and to a point, less gradient noise means faster convergence.\n", "\n", "But larger batches result in larger tensors,\n", "which take up more GPU memory,\n", "a resource that is tightly constrained\n", "and device-dependent.\n", "\n", "Does that mean we are limited in the quality of our gradients\n", "due to our machine size?\n", "\n", "Not entirely:\n", "look up the `--accumulate_grad_batches`\n", "argument to the `pl.Trainer`.\n", "You should be able to understand why\n", "it makes it possible to compute the same gradients\n", "you would find for a batch of size `k * N`\n", "on a machine that can only run batches up to size `N`.\n", "\n", "Accumulating gradients across batches is among the\n", "[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n", "Try some of them out!\n", "Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment." 
] }, { "cell_type": "markdown", "metadata": { "id": "b2vtkmX830y3" }, "source": [ "### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n", "\n", "While training this model to actually fit the whole dataset is infeasible\n", "as a short exercise on commodity hardware,\n", "it's practical to train this model to memorize a batch of 16 examples.\n", "\n", "Passing the `--overfit_batches 1` flag limits the number of training batches to 1\n", "and turns off\n", "[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n", "so that in each epoch, the model just sees the same single batch of data over and over again.\n", "\n", "At first, try training the model to a loss of `2.5` --\n", "it should be doable in 100 epochs or fewer,\n", "which is just a few minutes on a commodity GPU.\n", "\n", "Once you've got that working,\n", "crank up the number of epochs by a factor of 10\n", "and confirm that the loss continues to go down.\n", "\n", "Some tips:\n", "\n", "- Use `--limit_test_batches 0` to turn off testing.\n", "We don't need it because we don't care about generalization\n", "and it's relatively slow because it runs the model autoregressively.\n", "\n", "- Use `--help` and look through the model class args\n", "to find the arguments used to reduce model size.\n", "\n", "- By default, there's lots of regularization to prevent overfitting.\n", "Look through the args for the model class and data class\n", "for regularization knobs to turn off or down."
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab03_transformers.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab07/notebooks/lab04_experiments.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 04: Experiment Management" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How experiment management brings observability to ML model development\n", "- Which features of experiment management we use in developing the Text Recognizer\n", "- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 4\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", "    # path management for Python\n", "    pythonpath, = !echo $PYTHONPATH\n", "    if \".\" not in pythonpath.split(\":\"):\n", "        pythonpath = \".:\" + pythonpath\n", "        %env PYTHONPATH={pythonpath}\n", "        !echo $PYTHONPATH\n", "\n", "    # get both Colab and local notebooks into the same state\n", "    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", "    import bootstrap\n", "\n", "    # change into the lab directory\n", "    bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", "    # allow \"hot-reloading\" of modules\n", "    %load_ext autoreload\n", "    %autoreload 2\n", "    # needed for inline plots in some contexts\n", "    %matplotlib inline\n", "\n", "    bootstrap.run = False  # change to True to re-run setup\n", "    \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This lab contains a large number of embedded iframes\n", "that benefit from having a wide window.\n", "The cell below makes the notebook as wide as your browser window\n", "if `full_width` is set to `True`.\n", "Full width is the default behavior in Colab,\n", "so this cell is intended to improve the viewing experience in other Jupyter environments."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720  # adjust for your screen\n", "\n", "if full_width:  # if we want the notebook to take up the whole width\n", "    # add styling to the notebook's HTML directly\n", "    display(HTML(\"\"))\n", "    display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Why experiment management?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand why we need experiment management for ML development,\n", "let's start by running an experiment.\n", "\n", "We'll train a new model on a new dataset,\n", "using the training script `training/run_experiment.py`\n", "introduced in [Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a CNN encoder and Transformer decoder, as in\n", "[Lab 03](https://fsdl.me/lab03-colab),\n", "but with some changes so we can iterate faster.\n", "We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n", "[Lab 02b](https://fsdl.me/lab02b-colab),\n", "and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n", "from text_recognizer.data import IAMLines # processed version split into individual lines\n", "from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n", "\n", "\n", "print(IAM.__doc__)\n", "\n", "# uncomment a line below for details on either class\n", "# IAMLines?? \n", "# LineCNNTransformer??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below will train a model on 10% of the data for two epochs.\n", "\n", "It takes up to a few minutes to run on commodity hardware,\n", "including data download and preprocessing.\n", "As it's running, continue reading below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%%time\n", "import torch\n", "\n", "\n", "gpus = int(torch.cuda.is_available()) \n", "\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the model trains, we're calculating lots of metrics --\n", "loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n", "and reporting them to the terminal.\n", "\n", "This is achieved by the built-in `.log` method\n", "([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n", "of the `LightningModule`,\n", "and it is a very straightforward way to get basic information about your experiment as it's running\n", "without leaving the context where you're running it." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Learning to read\n", "[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n", "is something of a rite of passage for MLEs, but\n", "let's consider what we can't see here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We're missing all metric values except the most recent --\n", "we can see them as they stream in, but they're constantly overwritten.\n", "We also can't associate them with timestamps, steps, or epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We also don't see any system metrics.\n", "We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n", "without launching a separate process.\n", "And even if we do, those values will also not be saved and timestamped,\n", "so we can't correlate them with other things during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- As we continue to run experiments, changing code and opening new terminals,\n", "even the information we have or could figure out now will disappear.\n", "Say you spot a weird error message during training,\n", "but your session ends and the stdout is gone,\n", "so you don't know exactly what it was.\n", "Can you recreate the error?\n", "Which git branch and commit were you on?\n", "Did you have any uncommitted changes? Which arguments did you pass?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Also, model checkpoints containing the parameter values have been saved to disk.\n", "Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n", "As we run more and more experiments,\n", "we'll want to slice and dice them to see if,\n", "say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to save and log all of this information, and more, in order to make our model training\n", "[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n", "in short, so that we can understand, make decisions about, and debug our model training\n", "by looking at logs and source code, without having to recreate it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we had to write the logging code we need to save this information ourselves, that'd put us in for a world of hurt:\n", "1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n", "2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n", "3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n", "4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Local Experiment Tracking with TensorBoard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n", "and it makes logging very easy."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n", "\n", "It can also use a logger to push information elsewhere.\n", "\n", "By default, we use\n", "[TensorBoard](https://www.tensorflow.org/tensorboard)\n", "via the Lightning `TensorBoardLogger`,\n", "which has been saving results to the local disk.\n", "\n", "Let's find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest experiment's directory\n", "# by hand, you can just copy and paste it from the terminal\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n", "filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n", "latest_log" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!ls -lh {latest_log}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, we need to launch a TensorBoard server --\n", "much like we need to launch a Jupyter server to use Jupyter notebooks.\n", "\n", "The cells below load an extension that lets you use TensorBoard inside of a notebook\n", "the same way you'd use it from the command line, and then launch it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext tensorboard" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n", "\n", "port = 11717 # pick an open port on your machine\n", "host = \"0.0.0.0\" # allow connections from the internet\n", " # watch out! make sure you turn TensorBoard off\n", "\n", "%tensorboard --logdir {latest_log} --port {port} --host {host}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see some charts of metrics over time along with some charting controls.\n", "\n", "You can click around in this interface and explore it if you'd like,\n", "but in the next section, we'll see that there are better tools for experiment management." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've run many experiments on this machine,\n", "you can see all of their results by pointing TensorBoard\n", "at the whole `lightning_logs` directory,\n", "rather than just one experiment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For large numbers of experiments, the management experience is not great --\n", "for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n", "\n", "It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n", "which are important workflows as applications mature and teams grow."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n", "\n", "If you run into any issues with the above cells failing to launch,\n", "especially across iterations of this lab, run this cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorboard.manager\n", "\n", "# get the process IDs for all tensorboard instances\n", "pids = [tb.pid for tb in tensorboard.manager.get_all()]\n", "\n", "done_with_tensorboard = False\n", "\n", "if done_with_tensorboard:\n", "    # kill processes\n", "    for pid in pids:\n", "        !kill {pid} 2> /dev/null\n", "    \n", "    # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n", "    !rm -rf {tensorboard.manager._get_info_dir()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment Management with Weights & Biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How do we manage experiments when we hit the limits of local TensorBoard?"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is powerful and flexible and very scalable,\n", "but running it requires engineering effort and babysitting --\n", "you're running a database, writing data to it,\n", "and layering a web application over it.\n", "\n", "This is a fairly common workflow for web developers,\n", "but not so much for ML engineers.\n", "\n", "You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n", "and it's as simple as running the command `tensorboard dev upload`\n", "pointed at your logging directory.\n", "\n", "But there are strict limits to this free service:\n", "1GB of tensor data and 1GB of binary data.\n", "A single Text Recognizer model checkpoint is ~100MB,\n", "and that's not particularly large for a useful model.\n", "\n", "Furthermore, all data is public,\n", "so if you upload the inputs and outputs of your model,\n", "anyone who finds the link can see them.\n", "\n", "Overall, tensorboard.dev works very well for certain academic and open projects\n", "but not for industrial ML." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To avoid that narrow permissions and limits issue,\n", "you could use [git LFS](https://git-lfs.github.com/)\n", "to track the binary data and tensor data,\n", "which is more likely to be sensitive than metrics.\n", "\n", "The Hugging Face ecosystem uses TensorBoard and git LFS.\n", "\n", "It includes the Hugging Face Hub, a git server much like GitHub,\n", "but designed first and foremost for collaboration on models and datasets,\n", "rather than collaboration on code.\n", "For example, the Hugging Face Hub\n", "[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n", "and officially has\n", "[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n", "avoiding the\n", "[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n", "that make using git LFS with GitHub expensive.\n", "\n", "However, we prefer to avoid mixing software version control and experiment management.\n", "\n", "First, using the Hub requires maintaining an additional git remote,\n", "which is a hard ask for many engineering teams.\n", "\n", "Secondly, git-style versioning is an awkward fit for logging --\n", "is it really sensible to create a new commit for each logging event while you're watching live?\n", "\n", "Instead, we prefer to use systems that solve experiment management with _databases_." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n", "The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n", "tool is [MLflow](https://github.com/mlflow/mlflow/)\n", "and there are a number of\n", "[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n", "\n", "These tools generally avoid any need to worry about hosting\n", "(unless data governance rules require a self-hosted version).\n", "\n", "For a sampling of publicly-posted opinions on experiment management tools,\n", "see these discussions from Reddit:\n", "\n", "- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n", "- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n", "\n", "Among these tools, the FSDL recommendation is\n", "[Weights & Biases](https://wandb.ai),\n", "which we believe offers\n", "- the best user experience, both in the Python SDKs and in the graphical interface\n", "- the best integrations with other tools,\n", "including\n", "[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n", "[Keras](https://docs.wandb.ai/guides/integrations/keras),\n", "[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n", "and even\n", "[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n", "and\n", "- the best tools for collaboration.\n", "\n", "Below, we'll take care to point out which logging and management features\n", "are available via generic interfaces in Lightning and which are W&B-specific." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "\n", "print(wandb.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding it to our experiment running code is extremely easy,\n", "relative to the features we get, which is\n", "one of the main selling points of W&B.\n", "\n", "We get most of our new experiment management features just by changing a single variable, `logger`, from\n", "`TensorBoardLogger` to `WandbLogger`\n", "and adding two lines of code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see what each of these lines does for us below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this logger is built into and maintained by PyTorch Lightning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning.loggers import WandbLogger\n", "\n", "\n", "WandbLogger??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to complete the rest of this notebook,\n", "you'll need a Weights & Biases account.\n", "\n", "As with GitHub, the free tier for personal, academic, and open source work\n", "is very generous.\n", "\n", "The Text Recognizer project will fit comfortably within the free tier.\n", "\n", "Run the cell below and follow the prompts to log in or create an account, or go\n", "[here](https://wandb.ai/signup)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb login" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the cell below to launch an experiment tracked with Weights & Biases.\n", "\n", "The experiment can take between 3 and 10 minutes to run.\n", "In that time, continue reading below."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n", " --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1\n", " \n", "last_expt = wandb.run\n", "\n", "wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see some new things in our output.\n", "\n", "For example, there's a note from `wandb` that the data is saved locally\n", "and also synced to their servers.\n", "\n", "There's a link to a webpage for viewing the logged data and a name for our experiment --\n", "something like `dandy-sunset-1`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local logging and cloud syncing happen with minimal impact on performance,\n", "because `wandb` launches a separate process to listen for events and upload them.\n", "\n", "That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Runs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, head to the link in the notebook output\n", "that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n", "\n", "There's no need to wait for training to finish.\n", "\n", "The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(last_expt.url)\n", "IFrame(last_expt.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have landed on the run page\n", "([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n", "which collects up all of the information for a single experiment into a collection of tabs.\n", "\n", "We'll work through these tabs from top to bottom.\n", "\n", "Each header is also a link to the documentation for a tab." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n", "This tab has an icon that looks like `(i)` or 🛈.\n", "\n", "The top section of this tab has high-level information about our run:\n", "- Timing information, like start time and duration\n", "- System hardware, hostname, and basic environment info\n", "- Git repository link and state\n", "\n", "This information is collected and logged automatically.\n", "\n", "The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n", "and summary metrics.\n", "\n", "Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n", "\n", "This tab has a line plot icon, something like 📈.\n", "\n", "It's also the default page you land on when looking at a W&B run.\n", "\n", "Charts are generated for everything we `.log` from PyTorch Lightning. 
The charts here are interactive and editable, and changes persist.\n", "\n", "Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n", "\n", "We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n", "This tab has a computer chip icon.\n", "\n", "It contains\n", "- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n", "- CPU metrics: memory usage, utilization, thread counts\n", "- Disk and network I/O levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n", "This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n", "\n", "The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n", "This tab has an icon that looks like a stylized command prompt, `>_`.\n", "\n", "It contains information that was printed to stdout.\n", "\n", "This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n", "\n", "Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelSummary\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n", "\n", "For more on Lightning `Callback`s, see\n", "[Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n", "This tab has a stylized document icon, something like 📄.\n", "\n", "You can use this tab to view any files saved with the `wandb.save`.\n", "\n", "For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n", "which we'll discuss shortly.\n", "\n", "But a few pieces of information automatically collected by W&B end up in this tab.\n", "\n", "Some highlights:\n", " - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n", " - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n", "This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n", "\n", "This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n", "\n", "We store two kinds of binary files\n", " - `run_table`s of model inputs and outputs\n", " - `model` checkpoints\n", "\n", "We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive Tables of Logged Media" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returning to the Charts tab,\n", "notice that we have model inputs and outputs logged in structured tables\n", "under the train, validation, and test sections.\n", "\n", "These tables are interactive as well\n", "([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n", "They support basic exploratory data analysis and are compatible with W&B's collaboration features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to charts in our run page, these tables also have their own pages inside the W&B web app." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n", "table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n", "\n", "print(table_data_url)\n", "IFrame(src=table_data_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting this to work requires more effort and more W&B-specific code\n", "than the other features we've seen so far.\n", "\n", "We'll briefly explain the implementation here, for those who are interested.\n", "\n", "We use a custom Lightning `Callback`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n", "\n", "\n", "ImageToTextTableLogger??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n", "\n", "The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n", "\n", "This behavior is sensible for metrics, which are low overhead, but not so much for media,\n", "where we'd rather subsample and avoid holding on to too much information.\n", "\n", "So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n", "\n", "The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.lit_models.base import BaseImageToTextLitModel\n", "\n", "BaseImageToTextLitModel.add_on_logged_batches??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Projects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Everything we've seen so far has been related to a single run or experiment.\n", "\n", "Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n", "\n", "We organize our runs into \"projects\" and view them on the W&B \"project page\" \n", "([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n", "\n", "By default in the Lightning integration, the project name is determined based on directory information.\n", "This default can be overridden in the code when creating a `WandbLogger`,\n", "but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what the project page looks like for a longer-running project with lots of experiments.\n", "\n", "The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n", "\n", "print(project_url)\n", "IFrame(src=project_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n", "\n", "We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Artifacts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond logging metrics and metadata from runs,\n", "we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below pulls up all of the artifacts associated with the experiment we just ran."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Click on one of the `model` checkpoints -- the specific version doesn't matter.\n", "\n", "There are a number of tabs here.\n", "\n", "The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n", "\n", "The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n", "which are added by default by the `WandbLogger`.\n", "\n", "The \"Files\" tab contains the actual file contents of the artifact.\n", "\n", "On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n", "including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n", "\n", "You can click on these to explore the different versions and even directly compare them.\n", "\n", "If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Artifact storage is part of the W&B free tier.\n", "\n", "The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n", "\n", "The former is sufficient to store ~700 model checkpoints for the Text Recognizer." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can track your data storage and compare it to your limits at this URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n", "\n", "print(storage_tracker_url)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Programmatic Access" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also programmatically access our data and metadata via the `wandb` API\n", "([docs](https://docs.wandb.ai/guides/track/public-api-guide)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wb_api = wandb.Api()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run = wb_api.run(\"/\".join( # fetch a run given\n", " [last_expt.entity, # the user or org it was logged to\n", " last_expt.project, # the \"project\", usually one of several per repo/application\n", " last_expt.id] # and a unique ID\n", "))\n", "\n", "hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n", "\n", "hist.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hist.groupby(\"epoch\")[\"train/loss\"].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this includes the artifacts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# which artifacts were created and logged?\n", "artifacts = run.logged_artifacts()\n", "\n", "for artifact in artifacts:\n", "    print(f\"artifact of type {artifact.type}: {artifact.name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thanks to our 
`ImageToTextTableLogger`,\n", "we can easily recreate training or validation data that came out of our `DataLoader`s,\n", "which is normally ephemeral:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n", "artifact_dir = Path(artifact.download(root=\"training/logs\"))\n", "image_dir = artifact_dir / \"media\" / \"images\"\n", "\n", "images = [path for path in image_dir.iterdir()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "from IPython.display import Image\n", "\n", "Image(str(random.choice(images)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Advanced W&B API Usage: MLOps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the strengths of a well-instrumented experiment tracking system is that it allows\n", "automatic relation of information:\n", "what were the inputs when this model's gradient spiked?\n", "Which models have been trained on this dataset,\n", "and what was their performance?\n", "\n", "Having access and automation around this information is necessary for \"MLOps\",\n", "which applies contemporary DevOps principles to ML projects." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cells below pull down the training data\n", "for the model currently running the FSDL Text Recognizer app.\n", "\n", "This is just intended as a demonstration of what's possible,\n", "so don't worry about understanding every piece of this,\n", "and feel free to skip past it.\n", "\n", "MLOps is still a nascent field, and these tools and workflows are likely to change.\n", "\n", "For example, just before the course launched, W&B released a\n", "[Model Registry layer](https://docs.wandb.ai/guides/models)\n", "on top of artifact logging that aims to improve the developer experience for these workflows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start from the same project we looked at in the project view:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n", "\n", "text_recognizer_project " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we search it for the text recognizer model currently being used in production:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# collect all versions of the text-recognizer ever put into production by...\n", "\n", "for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n", " if art_type.name == \"prod-ready\": # for the prod-ready type\n", " # and grabbing the text-recognizer\n", " production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n", "\n", "# and then get the one that's currently being tested in CI by...\n", "for text_recognizer in production_text_recognizers:\n", " if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n", " in_prod_text_recognizer = text_recognizer\n", "\n", "# view its metadata at the url or 
in the notebook\n", "in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n", "\n", "print(in_prod_text_recognizer_url)\n", "IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From its metadata, we can get information about how it was \"staged\" to be put into production,\n", "and in particular which model checkpoint was used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "staging_run = in_prod_text_recognizer.logged_by()\n", "\n", "training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n", "training_ckpt.name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That checkpoint was logged by a training experiment, which is available as metadata.\n", "\n", "We can look at the training run for that model, either here in the notebook or at its URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "training_run = training_ckpt.logged_by()\n", "print(training_run.url)\n", "IFrame(src=training_run.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And from there, we can access logs and metadata about training,\n", "confident that we are working with the model that is actually in production.\n", "\n", "For example, we can pull down the data we logged and analyze it locally." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_results = training_run.history(samples=10000)\n", "training_results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n", "training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 10\n", "training_results[\"validation/loss\"].dropna().iloc[idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The charts and webpages in Weights & Biases\n", "are substantially more useful than ephemeral stdouts or raw logs on disk.\n", "\n", "If you're spun up on the project,\n", "they accelerate debugging, exploration, and discovery.\n", "\n", "If not, they're not so much useful as they are overwhelming.\n", "\n", "We need to synthesize the raw logged data into information.\n", "This helps us communicate our work with other stakeholders,\n", "preserve knowledge and prevent repetition of work,\n", "and surface insights faster.\n", "\n", "These workflows are supported by the W&B Reports feature\n", "([docs here](https://docs.wandb.ai/guides/reports)),\n", "which mixes W&B charts and tables with explanatory markdown text and embeds.\n", "\n", "Below are some common report patterns and\n", "use cases and examples of each." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the examples are from the FSDL Text Recognizer project.\n", "You can find more of them\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n", "where we've organized them into a report!"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dashboard Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dashboards are a structured subset of the output from one or more experiments,\n", "designed for quickly surfacing issues or insights,\n", "like an accuracy or performance regression\n", "or a change in the data distribution.\n", "\n", "Use cases:\n", "- show the basic state of an ongoing experiment\n", "- compare one experiment to another\n", "- select the most important charts so you can spin back up into context on a project more quickly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n", "\n", "IFrame(src=dashboard_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull Request Documentation Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In most software codebases,\n", "pull requests are a key focal point\n", "for units of work that combine\n", "short-term communication and long-term information tracking.\n", "\n", "In ML codebases, it's more difficult to bring\n", "sufficient information together to make PRs as useful.\n", "At FSDL, we like to add documentary\n", "reports with one or a small number of charts\n", "that connect logged information in the experiment management system\n", "to state in the version control software.\n", "\n", "Use cases:\n", "- communication of results within a team, e.g. 
code review\n", "- record-keeping that links pull request pages to raw logged info and makes it discoverable\n", "- improving confidence in PR correctness" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n", "\n", "IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Blog Post Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With sufficient effort, the logged data in the experiment management system\n", "can be made clear enough to be consumed,\n", "sufficiently contextualized to be useful outside the team, and\n", "even beautiful.\n", "\n", "The result is a report that's closer to a blog post than a dashboard or internal document.\n", "\n", "Use cases:\n", "- communication between teams or vertically in large organizations\n", "- external technical communication for branding and recruiting\n", "- attracting users or contributors\n", "\n", "Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma)\n", "and others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n", "\n", "IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many of our choices, like the depth of our network, the nonlinearities of our layers,\n", "and the learning rate and other parameters of our optimizer, cannot be\n", "([easily](https://arxiv.org/abs/1606.04474))\n", "chosen by descent of 
the gradient of a loss function.\n", "\n", "But these parameters that impact the values of the parameters\n", "we directly optimize with gradients, or _hyperparameters_,\n", "can still be optimized,\n", "essentially by trying options and selecting the values that worked best.\n", "\n", "In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n", "\n", "Expending more compute can squeeze small amounts of additional validation or test performance\n", "that makes for impressive results on leaderboards but typically doesn't translate\n", "into better user experience.\n", "\n", "In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n", "built into your other tooling.\n", "\n", "Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n", "([docs](https://docs.wandb.ai/guides/sweeps)).\n", "\n", "It also supports a number of more advanced tools, like\n", "[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n", "for early termination of poorly-performing runs.\n", "\n", "We can use the same training script and we don't need to run an optimization server.\n", "\n", "We just need to write a configuration yaml file\n", "([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n", "like the one below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile training/simple-overfit-sweep.yaml\n", "# first we specify what we're sweeping\n", "# we specify a program to run\n", "program: training/run_experiment.py\n", "# we optionally specify how to run it, including setting default arguments\n", "command: \n", " - ${env}\n", " - ${interpreter}\n", " - ${program}\n", " - \"--wandb\"\n", " - \"--overfit_batches\"\n", " - \"1\"\n", " - \"--log_every_n_steps\"\n", " - \"25\"\n", " - \"--max_epochs\"\n", " - \"100\"\n", " - \"--limit_test_batches\"\n", " - \"0\"\n", " - ${args} # these arguments come from the sweep parameters below\n", "\n", "# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n", "method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n", "metric:\n", " name: train/loss\n", " goal: minimize\n", "parameters: \n", " # LineCNN hyperparameters\n", " window_width:\n", " values: [8, 16, 32, 64]\n", " window_stride:\n", " values: [4, 8, 16, 32]\n", " # Transformer hyperparameters\n", " tf_layers:\n", " values: [1, 2, 4, 8]\n", " # we can also fix some values, just like we set default arguments\n", " gpus:\n", " value: 1\n", " model_class:\n", " value: LineCNNTransformer\n", " data_class:\n", " value: IAMLines\n", " loss:\n", " value: transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the config we launch a \"controller\":\n", "a lightweight process that just decides what hyperparameters to try next\n", "and coordinates the heavierweight training.\n", "\n", "This lives on the W&B servers, so there are no headaches about opening ports for communication,\n", "cleaning up when it's done, etc." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n", "simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we can launch an \"agent\" to follow the orders of the controller:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "\n", "# interrupt twice to terminate this cell if it's running too long,\n", "# it can be over 15 minutes with some hyperparameters\n", "\n", "!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n", "\n", "If not provided, the agent will run forever for random or Bayesian sweeps\n", "or until the sweep is terminated, which can be done from the W&B interface." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agents make for a slick workflow for distributing sweeps across GPUs.\n", "\n", "We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n", "which controls which GPUs are accessible by a process, to launch\n", "parallel agents on separate GPUs on the same machine." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n", "# open another terminal\n", "CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n", "# and so on\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include optional exercises with the labs for learners who want to dive deeper on specific topics." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟Contribute to a hyperparameter search." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n", "\n", "There are ~10,000,000 potential hyperparameter combinations,\n", "and each takes 30 minutes to test,\n", "so checking each possibility will take over 500 years of compute time.\n", "Best get cracking then!\n", "\n", "Run the cell below to pull up a dashboard and print the URL where you can check on the current status." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_entity = \"fullstackdeeplearning\"\n", "sweep_project = \"fsdl-line-recognizer-2022\"\n", "sweep_id = \"e0eo43eu\"\n", "sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n", "\n", "print(sweep_url)\n", "IFrame(src=sweep_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve information about the sweep from the API,\n", "including the hyperparameters being swept over." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparams = sweep_info.config[\"parameters\"]\n", "hyperparams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you'd like to contribute to this sweep,\n", "run the cell below after changing the count to a number greater than 0.\n", "\n", "Each iteration runs for 30 minutes if it does not crash,\n", "e.g. due to out-of-memory errors." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "count = 0 # off by default, increase it to join in!\n", "\n", "if count:\n", " !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Write some manual logging in `wandb`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the FSDL Text Recognizer codebase,\n", "we almost exclusively log to W&B through Lightning,\n", "rather than through the `wandb` Python SDK.\n", "\n", "If you're interested in learning how to use W&B directly, e.g. with another training framework,\n", "try out this quick exercise that introduces the key players in the SDK." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n", "\n", "It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n", "\n", "Add W&B metric and artifact logging to this cell:\n", "- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n", "- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import random\n", "\n", "import wandb\n", "\n", "\n", "os.makedirs(\"logs\", exist_ok=True)\n", "\n", "project = \"trying-wandb\"\n", "config = {\"steps\": 50}\n", "\n", "\n", "with wandb.init(project=project, config=config) as run:\n", " steps = wandb.config[\"steps\"]\n", " \n", " for ii in range(steps):\n", " loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n", " \n", " with open(\"logs/hello.txt\", \"w\") as f:\n", " f.write(\"hello from wandb, 
my dudes!\")\n", " \n", " run_id = run.id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hello_run = wb_api.run(f\"{project}/{run_id}\")\n", "\n", "# check for logged loss data\n", "if \"loss\" not in hello_run.history().keys():\n", " print(\"loss not logged 🥲\")\n", "else:\n", " print(\"loss logged successfully 🥞\")\n", " if len(hello_run.history()[\"loss\"]) != steps:\n", " print(\"loss not logged on all steps 🥲\")\n", " else:\n", " print(\"loss logged on all steps 🥞\")\n", "\n", "artifacts = hello_run.logged_artifacts()\n", "\n", "# check for artifact with the right name\n", "if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n", " print(\"hello artifact not logged 🥲\")\n", "else:\n", " print(\"hello artifact logged successfully 🥞\")\n", " # check for the file inside the artifacts\n", " if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n", " print(\"could not find hello.txt 🥲\")\n", " else:\n", " print(\"hello.txt logged successfully 🥞\")\n", " \n", " \n", "hello_run" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you observe interesting phenomena during training,\n", "from promising hyperparameter combos to software bugs to strange model behavior,\n", "turn the charts into a W&B report and share it with the FSDL community or\n", "[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n", "with a link to them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# check the sweep_info.config above to see the model and data hyperparameters\n", "# read through the --help output for all potential arguments\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n", " --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n", " --help # remove this line to run an experiment instead of printing help\n", " \n", "last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟🌟🌟 Add logging of tensor statistics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to logging model inputs and outputs as human-interpretable media,\n", "it's also frequently useful to see information about their numerical values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're interested in learning more about metric calculation and logging with Lightning,\n", "use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n", "to add tensor statistic logging to the `LineCNNTransformer`.\n", "\n", "`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n", "\n", "All three are useful, but start by adding just one." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n", "- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n", " - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n", "- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n", " - Base your code on the calculation and logging of the `val_cer` metric.\n", " - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ", "collapsed_sections": [], "name": "lab04_experiments.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab07/notebooks/lab05_troubleshooting.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 05: Troubleshooting & Testing" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`\n", "- How to implement tests for ML training systems in particular\n", "- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 5\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sThWeTtV6fL_" }, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "xFP8lU4nSg1P" }, "source": [ "# Linting Python and Shell Scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "cXbdYfFlPhZ-" }, "source": [ "### Automatically linting with `pre-commit`" ] }, { "cell_type": "markdown", "metadata": { "id": "ysqqb2GjvLrz" }, "source": [ "We 
want to keep our code clean and uniform across developers\n", "and time.\n", "\n", "Applying the cleanliness checks and style rules should be\n", "as painless and automatic as possible.\n", "\n", "For this purpose, we recommend bundling linting tools together\n", "and enforcing them on all commits with\n", "[`pre-commit`](https://pre-commit.com/)." ] }, { "cell_type": "markdown", "metadata": { "id": "XvqtZChKvLr0" }, "source": [ "In addition to running on every commit,\n", "`pre-commit` separates the model development environment from the environments\n", "needed for the linting tools, preventing conflicts\n", "and simplifying maintenance and onboarding." ] }, { "cell_type": "markdown", "metadata": { "id": "Y0XuIuKOXhJl" }, "source": [ "This cell runs `pre-commit`.\n", "\n", "The first time it is run on a machine, it will install the environments for all tools." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hltYGbpNvLr1" }, "outputs": [], "source": [ "!pre-commit run --all-files" ] }, { "cell_type": "markdown", "metadata": { "id": "gLw08gIkvLr1" }, "source": [ "The output lists all the checks that are run and whether they passed.\n", "\n", "Notice there are a number of simple version-control hygiene practices included\n", "that aren't even specific to Python, much less to machine learning.\n", "\n", "For example, several of the checks prevent accidental commits with private keys, large files, \n", "leftover debugger statements, or merge conflict annotations in them."
] }, { "cell_type": "markdown", "metadata": { "id": "RHEEjb9kvLr1" }, "source": [ "These linting actions are configured via\n", "([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n", "a YAML file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dgXa8BzrvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml" ] }, { "cell_type": "markdown", "metadata": { "id": "8HYc_WbTvLr2" }, "source": [ "Most of the general cleanliness checks are from hooks built by `pre-commit`.\n", "\n", "See the comments and links in the `.pre-commit-config.yaml` for more:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K9rTgRqzvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep repos -A 15" ] }, { "cell_type": "markdown", "metadata": { "id": "1ptkO7aPvLr2" }, "source": [ "Let's take a look at the section of the file\n", "that applies most of our Python style enforcement with\n", "[`flake8`](https://flake8.pycqa.org/en/latest/):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ALsRKfcevLr3", "scrolled": true }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10" ] }, { "cell_type": "markdown", "metadata": { "id": "a_Q0BwQUXbg6" }, "source": [ "The majority of the style checking behavior we want comes from the\n", "`additional_dependencies`, which are\n", "[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n", "that extend `flake8`'s list of lints.\n", "\n", "Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n", "\n", "We keep the configuration information for `flake8`\n", "separate from that for `pre-commit`\n", "in case we want to use additional tools with `flake8`,\n", "e.g. 
if some developers want to integrate it directly into their editor,\n", "and so that if we change away from `pre-commit`\n", "but keep `flake8` we don't have to\n", "recreate our configuration in a different tool.\n", "\n", "As much as possible, codebases should strive for single sources of truth\n", "and link back to those sources of truth with documentation or comments,\n", "as in the last line above.\n", "\n", "Let's take a look at the contents of `.flake8`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "doC_4WQwvLr3" }, "outputs": [], "source": [ "!cat .flake8" ] }, { "cell_type": "markdown", "metadata": { "id": "0Nq6HnyU0M47" }, "source": [ "There's a lot here! We'll focus on the most important bits." ] }, { "cell_type": "markdown", "metadata": { "id": "U4PiB8CPvLr3" }, "source": [ "Linting tools in Python generally work by emitting error codes\n", "with one or more letters followed by three numbers.\n", "The `select` argument picks which error codes we want to check for.\n", "Error codes are matched by prefix,\n", "so for example `B` matches `BTS101` and\n", "`G1` matches `G102` and `G199` but not `ARG404`.\n", "\n", "Certain codes are `ignore`d in the default `flake8` style,\n", "which is done via the `ignore` argument,\n", "and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n", "For example, we rely on `black` to do our formatting,\n", "so we ignore some of `flake8`'s formatting codes.\n", "\n", "Together, these settings define our project's particular style.\n", "\n", "But not every file fits this style perfectly.\n", "Most of the conventions in `black` and `flake8` come from the style-defining\n", "[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n", "which exhorts you to \"know when to be inconsistent\".\n", "\n", "To allow ourselves to be inconsistent when we know we should be,\n", "`flake8` includes `per-file-ignores`,\n", "which let us ignore specific warnings in specific files.\n", 
"This is one of the \"escape valves\"\n", "that makes style enforcement tolerable.\n", "We can also `exclude` files in the `pre-commit` config itself.\n", "\n", "For details on selecting and ignoring,\n", "see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html)\n", "\n", "For definitions of the error codes from `flake8` itself,\n", "see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n", "Individual extensions list their added error codes in their documentation,\n", "e.g. `darglint` does so\n", "[here](https://github.com/terrencepreilly/darglint#error-codes)." ] }, { "cell_type": "markdown", "metadata": { "id": "NL0TpyPsvLr4" }, "source": [ "The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n", "\n", "You can read more about each in their documentation:\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "markdown", "metadata": { "id": "mFsZC0a7vLr4" }, "source": [ "### Linting via a script and using `shellcheck`" ] }, { "cell_type": "markdown", "metadata": { "id": "RYjpuFwjXkJc" }, "source": [ "To avoid needing to think about `pre-commit`\n", "(was the command `pre-commit run` or `pre-commit check`?)\n", "while developing locally,\n", "we might put our linters into a shell script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mXlLFWmavLr4" }, "outputs": [], "source": [ "!cat tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "PPxHpRIB3nbw" }, "source": [ "These kinds of short and simple shell scripts are common in projects\n", "of intermediate 
size.\n", "\n", "They are useful for adding automation and reducing friction." ] }, { "cell_type": "markdown", "metadata": { "id": "TMuPBpAi2qwl" }, "source": [ "But these scripts are code,\n", "and all code is susceptible to bugs and subject to concerns of style consistency." ] }, { "cell_type": "markdown", "metadata": { "id": "SQRg3ZqXvLr4" }, "source": [ "We can't check these scripts with tools that lint Python code,\n", "so we include a shell script linting tool,\n", "[`shellcheck`](https://www.shellcheck.net/),\n", "in our `pre-commit`.\n", "\n", "More so than checking for correct style,\n", "this tool checks for common bugs or surprising behaviors of shells,\n", "which are unfortunately numerous." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zkfhE1srvLr4" }, "outputs": [], "source": [ "script_filename = \"tasks/lint.sh\"\n", "!pre-commit run shellcheck --files {script_filename}" ] }, { "cell_type": "markdown", "metadata": { "id": "KXU9TRrwvLr4" }, "source": [ "That script has already been tested, so we don't see any errors.\n", "\n", "Try copying over a script you've written yourself or\n", "even from a popular repo that you like\n", "(by adding to the notebook directory or by making a cell\n", "with `%%writefile` at the top)\n", "and test it by changing the `script_filename`.\n", "\n", "You'd be surprised at the classes of subtle bugs possible in bash!" 
] }, { "cell_type": "markdown", "metadata": { "id": "81MhAL-TvLr5" }, "source": [ "### Try \"unofficial bash strict mode\" for louder failures in scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "hSwhs_zUvLr5" }, "source": [ "Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n", "[@redsymbol](https://twitter.com/redsymbol),\n", "which appear at the top of the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o-j0vSxEvLr5" }, "outputs": [], "source": [ "!head -n 3 tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "d2iJU5jlvLr5" }, "source": [ "The core idea of strict mode is to fail more loudly.\n", "This is a desirable behavior of scripts,\n", "like the ones we're writing,\n", "even though it's an undesirable behavior for an interactive shell --\n", "it would be unpleasant to be logged out every time you hit an error.\n", "\n", "`set -u` means scripts fail if a variable's value is `u`nset,\n", "i.e. not defined.\n", "Otherwise bash is perfectly happy to allow you to reference undefined variables.\n", "The result is just an empty string, which can lead to maddeningly weird behavior.\n", "\n", "`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n", "rather than using the exit code of the last command.\n", "Unix tools are perfectly happy to work on nonsense input,\n", "like sorting error messages, instead of the filenames you meant to send.\n", "\n", "You can read more about these choices\n", "[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n", "and considerations for working with other non-conforming scripts in \"strict mode\"\n", "and for handling resource teardown when scripts error out." 
] }, { "cell_type": "markdown", "metadata": { "id": "s1XqsrU_XWWS" }, "source": [ "# Testing ML Codebases" ] }, { "cell_type": "markdown", "metadata": { "id": "CPNzeq3NYF2W" }, "source": [ "## Testing Python code with `pytest`" ] }, { "cell_type": "markdown", "metadata": { "id": "zq5e_x6gc9Vu" }, "source": [ "\n", "ML codebases are Python first and foremost, so first let's get some Python tests going." ] }, { "cell_type": "markdown", "metadata": { "id": "0DC3GxYz6_R9" }, "source": [ "At a basic level,\n", "we can write functions that `assert`\n", "that our code behaves as expected in\n", "a given scenario and include them in the same module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Rvd-GNwv63W1" }, "outputs": [], "source": [ "from text_recognizer.lit_models.metrics import test_character_error_rate\n", "\n", "test_character_error_rate??" ] }, { "cell_type": "markdown", "metadata": { "id": "iVB2TsQS5BTq" }, "source": [ "The standard tool for testing Python code is\n", "[`pytest`](https://docs.pytest.org/en/7.1.x/).\n", "\n", "We can use it as a command-line tool in a variety of ways,\n", "including to execute these kinds of tests.\n", "\n", "If passed a filename, `pytest` will look for\n", "any classes that start with `Test` or\n", "any functions that start with `test_` and run them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u8sQguyJvLr6", "scrolled": false }, "outputs": [], "source": [ "!pytest text_recognizer/lit_models/metrics.py" ] }, { "cell_type": "markdown", "metadata": { "id": "92tkBCllvLr6" }, "source": [ "After the results of the tests (pass or fail) are returned,\n", "you'll see a report of \"coverage\" from\n", "[`codecov`](https://about.codecov.io/).\n", "\n", "This coverage report tells us which files and how many lines in those files\n", "were touched by the testing suite."
] }, { "cell_type": "markdown", "metadata": { "id": "PllSUe0s5xvU" }, "source": [ "We do not actually need to provide the names of files with tests in them to `pytest`\n", "in order for it to run our tests." ] }, { "cell_type": "markdown", "metadata": { "id": "4qOBHJnTZM9x" }, "source": [ "By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n", "\n", "It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n", "to separate these from the rest of your code\n", "in a folder or folders named `tests`,\n", "rather than scattering them around the repo." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "acjsYTNSvLr6" }, "outputs": [], "source": [ "!ls text_recognizer/tests" ] }, { "cell_type": "markdown", "metadata": { "id": "WZQQZUF0vLr6" }, "source": [ "Let's take a look at a specific example:\n", "the tests for some of our utilities around\n", "custom PyTorch Lightning `Callback`s." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oS0xKv1evLr6" }, "outputs": [], "source": [ "from text_recognizer.tests import test_callback_utils\n", "\n", "\n", "test_callback_utils.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "lko8msn-vLr7" }, "source": [ "Notice that we can easily import this as a module!\n", "\n", "That's another benefit of organizing tests into specialized files." ] }, { "cell_type": "markdown", "metadata": { "id": "5A85FUNv75Fr" }, "source": [ "The particular utility we're testing\n", "here is designed to prevent crashes:\n", "it checks for a particular type of error and turns it into a warning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jl4-DiVe76sw" }, "outputs": [], "source": [ "from text_recognizer.callbacks.util import check_and_warn\n", "\n", "check_and_warn??" 
] }, { "cell_type": "markdown", "metadata": { "id": "B6E0MhduvLr7" }, "source": [ "Error-handling code is a common cause of bugs,\n", "a fact discovered\n", "[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n", "so it's very important to test it well!\n", "\n", "We start with a very basic test,\n", "which does not touch anything\n", "outside of the Python standard library,\n", "even though this tool is intended to be used\n", "with more complex features of third-party libraries,\n", "like `wandb` and `tensorboard`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xx5koQmJvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_simple??" ] }, { "cell_type": "markdown", "metadata": { "id": "MZe9-JVjvLr7" }, "source": [ "Here, we are just testing the core logic.\n", "This test won't catch many bugs,\n", "but when it does fail, something has gone seriously wrong.\n", "\n", "These kinds of tests are important for resolving a bug:\n", "we learn nearly as much from the tests that passed\n", "as we do from the tests that failed.\n", "If this test has failed, possibly along with others,\n", "we can rule out an issue in one of the large external codebases\n", "touched in the other tests, saving us lots of time in our troubleshooting.\n", "\n", "The reasoning for the test is explained in the docstrings, \n", "which are close to the code.\n", "\n", "Your test suite should be as welcoming\n", "as the rest of your codebase!\n", "The people reading it, for example yourself in six months, \n", "are likely upset and in need of some kindness.\n", "\n", "More practically, we want to keep our time to resolve errors as short as possible,\n", "and five minutes to write a good docstring now\n", "can save five minutes during an outage, when minutes really matter."
] }, { "cell_type": "markdown", "metadata": { "id": "Om9k-uXhvLr7" }, "source": [ "That basic test is a start, but it's not enough by itself.\n", "There's a specific error case that triggered the addition of this code.\n", "\n", "So we test that it's handled as expected." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fjbsb5FvvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_tblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "CGAIZTUjvLr7" }, "source": [ "That test can fail if the libraries change around our code,\n", "i.e. if the `TensorBoardLogger` gets a `log_table` method.\n", "\n", "We want to be careful when making assumptions\n", "about other people's software,\n", "especially for fast-moving libraries like Lightning.\n", "If we test that those assumptions hold willy-nilly,\n", "we'll end up with tests that fail because of\n", "harmless changes in our dependencies.\n", "\n", "Tests that require a ton of maintenance and updating\n", "without leading to code improvements soak up\n", "more engineering time than they save\n", "and cause distrust in the testing suite.\n", "\n", "We include this test because `TensorBoardLogger` getting\n", "a `log_table` method will _also_ change the behavior of our code\n", "in a breaking way, and we want to catch that before it breaks\n", "a model training job." ] }, { "cell_type": "markdown", "metadata": { "id": "jsy95KAvvLr7" }, "source": [ "Adding error handling can also accidentally kill the \"happy path\"\n", "by raising an error incorrectly.\n", "\n", "So we explicitly test the _absence of an error_,\n", "not just its presence:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LRlIOkjmvLr8" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_wandblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "osiqpLynvLr8" }, "source": [ "There are more tests we could build, e.g. 
manipulating classes and testing the behavior,\n", "testing more classes that might be targeted by `check_and_warn`, or\n", "asserting that warnings are raised to the command line.\n", "\n", "But these three basic tests are likely to catch most changes that would break our code here,\n", "and they're a lot easier to write than the others.\n", "\n", "If this utility starts to get more usage and become a critical path for lots of features, we can always add more!" ] }, { "cell_type": "markdown", "metadata": { "id": "dm285JE5vLr8" }, "source": [ "## Interleaving testing and documentation with `doctests`" ] }, { "cell_type": "markdown", "metadata": { "id": "UHWQvgA8vLr8" }, "source": [ "One function of tests is to build user/reader confidence in code." ] }, { "cell_type": "markdown", "metadata": { "id": "wrhiJBXFvLr8" }, "source": [ "One function of documentation is to build user/reader knowledge in code." ] }, { "cell_type": "markdown", "metadata": { "id": "1vu12LDhvLr8" }, "source": [ "These functions are related. Let's put them together:\n", "put code in a docstring and test that code.\n", "\n", "This feature is part of the\n", "Python standard library via the\n", "[`doctest` module](https://docs.python.org/3/library/doctest.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "rmfIOwXd-Qt7" }, "source": [ "Here's an example from our `torch` utilities.\n", "\n", "The `first_appearance` function can be used to\n", "e.g. quickly look for stop tokens,\n", "giving the length of each sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZzURGcD9vLr8" }, "outputs": [], "source": [ "from text_recognizer.lit_models.util import first_appearance\n", "\n", "\n", "first_appearance??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0VtYcJ1WvLr8" }, "source": [ "Notice that in the \"Examples\" section,\n", "there's a short block of code formatted as a\n", "Python interpreter session,\n", "complete with outputs.\n", "\n", "We can copy and paste that code and\n", "check that we get the right outputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dj4lNOxJvLr9" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y9AWHFoIvLr9" }, "source": [ "We can run the test with `pytest` by passing a command line argument,\n", "`--doctest-modules`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JMaAxv5ovLr9" }, "outputs": [], "source": [ "!pytest --doctest-modules text_recognizer/lit_models/util.py" ] }, { "cell_type": "markdown", "metadata": { "id": "6-2_aOUfvLr9" }, "source": [ "With the\n", "[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n", "running `doctest`s happens automatically\n", "when `pytest` is invoked." ] }, { "cell_type": "markdown", "metadata": { "id": "my_keokPvLr9" }, "source": [ "## Basic tests for data code" ] }, { "cell_type": "markdown", "metadata": { "id": "Qj3Bq_j2_A8o" }, "source": [ "ML code can be hard to test\n", "since it involves very heavy artifacts, like models and data,\n", "and very expensive jobs, like training."
] }, { "cell_type": "markdown", "metadata": { "id": "DT5OmgrQvLr9" }, "source": [ "For testing our data-handling code in the FSDL codebase,\n", "we mostly just use `assert`s,\n", "which throw errors when behavior differs from expectation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bdzn5g4TvLr9" }, "outputs": [], "source": [ "!grep \"assert\" -r text_recognizer/data" ] }, { "cell_type": "markdown", "metadata": { "id": "2aTlfu4_vLr-" }, "source": [ "This isn't great practice,\n", "especially as a codebase grows,\n", "because we can't easily know when these are executed\n", "or incorporate them into\n", "testing automation and coverage analysis tools." ] }, { "cell_type": "markdown", "metadata": { "id": "IaMTdmbZ_mkW" }, "source": [ "So it's preferable to collect up these assertions of simple data properties\n", "into tests that are run like our other tests.\n", "\n", "The test below checks whether any data is leaking\n", "between training, validation, and testing." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qx7cxiDdvLr-" }, "outputs": [], "source": [ "from text_recognizer.tests.test_iam import test_iam_data_splits\n", "\n", "\n", "test_iam_data_splits??" 
] }, { "cell_type": "markdown", "metadata": { "id": "16TJwhd1vLr-" }, "source": [ "Notice that we were able to load the test into the notebook\n", "because it is in a module,\n", "and so we can run it here as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mArITFkYvLr-" }, "outputs": [], "source": [ "test_iam_data_splits()" ] }, { "cell_type": "markdown", "metadata": { "id": "E4F2uaclvLr-" }, "source": [ "But we're checking something pretty simple here,\n", "so the new code in each test is just a single line.\n", "\n", "What if we wanted to test more complex properties,\n", "like comparing rows or calculating statistics?\n", "\n", "We'll end up writing more complex code that might itself have subtle bugs,\n", "requiring tests for our tests and suffering from\n", "\"tester's regress\".\n", "\n", "This is the phenomenon,\n", "named by analogy with\n", "[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n", "in sociology of science,\n", "where the validity of our tests is itself\n", "up for dispute only resolvable by testing the tests,\n", "but those tests are themselves possibly invalid." ] }, { "cell_type": "markdown", "metadata": { "id": "nUGT06gdvLr-" }, "source": [ "We cut this Gordian knot by using\n", "a library or framework that is well-tested.\n", "\n", "We recommend checking out\n", "[`great_expectations`](https://docs.greatexpectations.io/docs/)\n", "if you're looking for a high-quality data testing tool." ] }, { "cell_type": "markdown", "metadata": { "id": "dQ5vNsq3vLr-" }, "source": [ "Especially with data, some tests are particularly \"heavy\" --\n", "they take a long time,\n", "and we might want to run them\n", "on different machines\n", "and on a different schedule\n", "than our other tests." 
] }, { "cell_type": "markdown", "metadata": { "id": "xephcb0LvLr-" }, "source": [ "For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n", "\n", "We can't just use a cached version of the data,\n", "since that won't actually execute the code!\n", "\n", "This test will take\n", "as long to run\n", "and consume as many resources as\n", "a full download of the data." ] }, { "cell_type": "markdown", "metadata": { "id": "YSN4w2EqvLr-" }, "source": [ "`pytest` allows the separation of tests\n", "into suites with `mark`s,\n", "which \"tag\" tests with names." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V0rScrcXvLr_", "scrolled": false }, "outputs": [], "source": [ "!pytest --markers | head -n 10" ] }, { "cell_type": "markdown", "metadata": { "id": "lr5Ca7B0vLr_" }, "source": [ "We can choose to run tests with a given mark\n", "or to skip tests with a given mark, \n", "among other basic logical operations around combining and filtering marks,\n", "with `-m`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xmw-Eb1ZvLr_" }, "outputs": [], "source": [ "!wandb login # one test requires wandb authentication\n", "\n", "!pytest -m \"not data and not slow\"" ] }, { "cell_type": "markdown", "metadata": { "id": "5LuERxOXX_UJ" }, "source": [ "## Testing training with memorization tests" ] }, { "cell_type": "markdown", "metadata": { "id": "AnWLN4lRvLsA" }, "source": [ "Training is the process by which we convert inert data into executable models,\n", "so it is dependent on both.\n", "\n", "We decouple checking whether the script has a critical bug\n", "from whether the data or model code is broken\n", "by testing on some basic \"fake data\",\n", "based on a utility from `torchvision`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k4NIc3uWvLsA" }, "outputs": [], "source": [ "from text_recognizer.data import FakeImageData\n", "\n", "\n", "FakeImageData.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "deN0swwlvLsA" }, "source": [ "We then test on the actual data with a smaller version of the real model.\n", "\n", "We use the Lightning `--fast_dev_run` feature,\n", "which sets the number of training, validation, and test batches to `1`.\n", "\n", "We use a smaller version so that this test can run in just a few minutes\n", "on a CPU without acceleration.\n", "\n", "That allows us to run our tests in environments without GPUs,\n", "which saves on costs for executing tests.\n", "\n", "Here's the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z4J0_uD9vLsA" }, "outputs": [], "source": [ "!cat training/tests/test_run_experiment.sh" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-7u9zS1vLsA", "scrolled": false }, "outputs": [], "source": [ "! ./training/tests/test_run_experiment.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "UTzfo11KClV3" }, "source": [ "The above tests don't actually check\n", "whether any learning occurs;\n", "they just check\n", "whether training runs mechanically,\n", "without any errors.\n", "\n", "We also need a\n", "[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n", "for learning.\n", "For that we recommend checking whether\n", "the model can learn the right\n", "outputs for a single batch --\n", "to \"memorize\" the outputs for\n", "a particular input.\n", "\n", "This memorization test won't\n", "catch every bug or issue in training,\n", "which is notoriously difficult,\n", "but it will flag\n", "some of the most serious issues." ] }, { "cell_type": "markdown", "metadata": { "id": "0DVSp3aAvLsA" }, "source": [ "The script below runs a memorization test."
] }, { "cell_type": "markdown", "metadata": { "id": "2DFVVrxpvLsA" }, "source": [ "It takes up to two arguments:\n", "a `MAX`imum number of `EPOCHS` to run for and\n", "a `CRITERION` value of the loss to test against.\n", "\n", "The test passes if the loss is lower than the `CRITERION` value\n", "after the `MAX`imum number of `EPOCHS` has passed." ] }, { "cell_type": "markdown", "metadata": { "id": "oEhJH0e5vLsB" }, "source": [ "The important line in this script is the one that invokes our training script,\n", "`training/run_experiment.py`.\n", "\n", "The arguments to `run_experiment` have been tuned for maximum possible speed:\n", "turning off regularization, shrinking the model,\n", "and skipping parts of Lightning that we don't want to test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T-fFs1xEvLsB" }, "outputs": [], "source": [ "!cat training/tests/test_memorize_iam.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "X-47tUA_YNGe" }, "source": [ "If you'd like to see what a memorization run looks like,\n", "flip the `running_memorization` flag to `True`\n", "and watch the results stream in to W&B.\n", "\n", "The cell should run in about ten minutes on a commodity GPU." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GwTEsZwKvLsB" }, "outputs": [], "source": [ "%%time\n", "running_memorization = False\n", "\n", "if running_memorization:\n", "    max_epochs = 1000\n", "    loss_criterion = 0.05\n", "    !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Troubleshooting model speed with the PyTorch Profiler" ] }, { "cell_type": "markdown", "metadata": { "id": "DpbN-Om2Drf-" }, "source": [ "Testing code is only half the story here:\n", "we also need to fix the issues that our tests flag.\n", "This is the process of troubleshooting.\n", "\n", "In this lab,\n", "we'll focus on troubleshooting model performance issues:\n", "what to do when your model runs too slowly." ] }, { "cell_type": "markdown", "metadata": { "id": "NZzwELPXvLsD" }, "source": [ "Troubleshooting deep neural networks for speed is challenging.\n", "\n", "There are at least three different common approaches,\n", "each with an increasing level of skill required:\n", "\n", "1. Follow best practices advice from others\n", "([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n", "[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n", "2. Take code that runs slowly and use empirical observations to iteratively improve it.\n", "3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n", "\n", "For the full stack deep learning engineer,\n", "the final level is typically out of reach,\n", "unless you're specializing in the model performance\n", "part of the stack in particular.\n", "\n", "So we recommend reaching the middle level,\n", "and this segment of the lab walks through the\n", "tools that make this easier."
] }, { "cell_type": "markdown", "metadata": { "id": "3_yp87UrFZ8M" }, "source": [ "Because neural network training involves GPU acceleration,\n", "generic Python profiling tools like\n", "[`py-spy`](https://github.com/benfred/py-spy)\n", "won't work, and\n", "we'll need tools specialized for tracing and profiling DNN training." ] }, { "cell_type": "markdown", "metadata": { "id": "yspsYVFGEyZm" }, "source": [ "In general, these tools are for observing what happens while your code is executing:\n", "_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n", "\n", "Because they help us observe the execution in detail,\n", "they will also help us understand just what is going on during\n", "a PyTorch training step in greater detail." ] }, { "cell_type": "markdown", "metadata": { "id": "YqXq2hKuvLsE" }, "source": [ "To support profiling and tracing,\n", "we've added a new argument to `training/run_experiment.py`, `--profile`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z_GMMViWvLsE" }, "outputs": [], "source": [ "!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\"" ] }, { "cell_type": "markdown", "metadata": { "id": "ZldoksHPvLsE" }, "source": [ "As with experiment management, this relies mostly on features of PyTorch Lightning,\n", "which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n", "and we just add a few lines of customization:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F2iJ0_A6vLsE" }, "outputs": [], "source": [ "!cat training/run_experiment.py | grep args.profile -A 5" ] }, { "cell_type": "markdown", "metadata": { "id": "Aw3ppgndvLsE" }, "source": [ "For more on profiling with Lightning, see the\n", "[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)." 
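] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get a feel for what Lightning is wrapping,\n", "here is a minimal sketch that calls `torch.profiler` directly.\n", "\n", "It is an illustration under assumptions, not the code path used by `run_experiment.py`:\n", "the tiny `toy_model`, the input shape, and the `toy_trace.json` filename are made up for this example,\n", "and it records only CPU activity so that it runs without a GPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.profiler import ProfilerActivity, profile, record_function\n", "\n", "\n", "# hypothetical stand-in for a real model, small enough to run on a CPU\n", "toy_model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))\n", "toy_inputs = torch.randn(16, 64)\n", "\n", "with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:\n", "    with record_function(\"toy_forward\"):  # label this region so it is easy to find in the trace\n", "        toy_model(toy_inputs)\n", "\n", "# the profile: aggregate statistics, one row per operation\n", "print(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=5))\n", "\n", "# the trace: a timeline of events that can be opened in the Chrome Trace Viewer\n", "prof.export_chrome_trace(\"toy_trace.json\")"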
] }, { "cell_type": "markdown", "metadata": { "id": "uCAmNW3QEtcD" }, "source": [ "The cell below runs an epoch of training with tracing and profiling turned on\n", "and then saves the results locally and to W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t4o3ylDgr46F", "scrolled": false }, "outputs": [], "source": [ "import glob\n", "\n", "import torch\n", "import wandb\n", "\n", "from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n", "\n", "\n", "# make it easier to separate these from training runs\n", "%env WANDB_JOB_TYPE=profile\n", "\n", "batch_size = 16\n", "num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n", "gpus = 1 # must be run with accelerator\n", "\n", "%run training/run_experiment.py --wandb --profile \\\n", " --max_epochs=1 \\\n", " --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n", " --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n", " --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus={gpus}\n", "\n", "latest_expt = wandb.run\n", "\n", "try: # add execution trace to logged and versioned binaries\n", "    folder = wandb.run.dir\n", "    trace_matcher = folder + \"/*.pt.trace.json\"\n", "    trace_file = glob.glob(trace_matcher)[0]\n", "    trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n", "    trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n", "    wandb.log_artifact(trace_at)\n", "except IndexError:\n", "    print(\"trace not found\")\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": { "id": "ePTkS3EqO5tN" }, "source": [ "We get out a table of statistics in the terminal,\n", "courtesy of Lightning.\n", "\n", "Each row lists an operation\n", "and provides information,\n", "described in the column headers,\n", "about the time spent on that operation\n", "across all the training steps we profiled.\n", "\n", "With 
practice, some useful information can be read out from this table,\n", "but it's better to start from both a less detailed view,\n", "in the TensorBoard dashboard,\n", "and a more detailed view,\n", "using the Chrome Trace viewer." ] }, { "cell_type": "markdown", "metadata": { "id": "TzV62f3c7-Bi" }, "source": [ "## High-level statistics from the PyTorch Profiler in TensorBoard" ] }, { "cell_type": "markdown", "metadata": { "id": "mNPKXkYw8NWd" }, "source": [ "Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CbItwuT88eAV" }, "outputs": [], "source": [ "your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n", "\n", "print(your_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "jE_LooMYHFpF" }, "source": [ "If at any point you run into issues,\n", "like the description not matching what you observe,\n", "check out one of our example runs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "za2zybSwIo5C" }, "outputs": [], "source": [ "example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n", "print(example_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "xlrhl1n4HYU6" }, "source": [ "Once the TensorBoard session has loaded up,\n", "we are dropped into the Overview\n", "(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n", "for an example).\n", "\n", "In the top center, we see the **GPU Summary** for our system.\n", "\n", "In addition to the name of our GPU,\n", "there are a few configuration details and top-level statistics.\n", "They are (tersely) documented\n", "[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)." 
] }, { "cell_type": "markdown", "metadata": { "id": "MmBhUDgDLhd1" }, "source": [ "- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n", "this is effectively a coarse \"version number\" for your GPU hardware.\n", "It indexes which features are available,\n", "with more advanced features being available only at higher compute capabilities.\n", "It does not directly index the speed or memory of the GPU." ] }, { "cell_type": "markdown", "metadata": { "id": "voUgT6zuLyi0" }, "source": [ "- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase." ] }, { "cell_type": "markdown", "metadata": { "id": "Yl-IndtXE4b4" }, "source": [ "- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n", "for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n", "Tensor Cores.\n", "If you're running on an older GPU without Tensor Cores,\n", "you should consider upgrading.\n", "If you're running a more recent GPU but not seeing Tensor Core usage,\n", "you should switch to half precision floating point numbers\n", "(as we do above with `--precision=16`),\n", "which Tensor Cores are specialized for." ] }, { "cell_type": "markdown", "metadata": { "id": "XxcUf0bBNXy_" }, "source": [ "- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n", "at a lower level than just whether something is running at all,\n", "as in utilization.\n", "Unlike utilization, reaching 100% is not generally feasible\n", "and sometimes not desirable.\n", "Increasing these numbers requires expertise in\n", "CUDA programming, so we'll target utilization instead."
] }, { "cell_type": "markdown", "metadata": { "id": "A88pQn4YMMKc" }, "source": [ "- **Execution Summary**: This table and pie chart indicate\n", "how much time within a profiled step\n", "was spent in each category.\n", "The value for \"kernel\" execution here\n", "is equal to the GPU utilization,\n", "and we want that number to be as close to 100%\n", "as possible.\n", "This summary helps us know which\n", "other operations are taking time,\n", "like memory being copied between CPU and GPU (`memcpy`)\n", "or `DataLoader`s executing on the CPU,\n", "so we can decide where the bottleneck is." ] }, { "cell_type": "markdown", "metadata": { "id": "6qjW1RlTQRPv" }, "source": [ "At the very bottom, you'll find a\n", "**Performance Recommendation**\n", "tab that sometimes suggests specific methods for improving performance.\n", "\n", "If this tab makes suggestions, you should certainly take them!" ] }, { "cell_type": "markdown", "metadata": { "id": "pWY5AhrcRQmJ" }, "source": [ "For more on using the profiler in TensorBoard,\n", "including some of the other, more detailed views\n", "available via the \"Views\" dropdown menu, see\n", "[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)."
] }, { "cell_type": "markdown", "metadata": { "id": "mQwrPY_H77H8" }, "source": [ "## Going deeper with the Chrome Trace Viewer" ] }, { "cell_type": "markdown", "metadata": { "id": "yhwo7fslvLsH" }, "source": [ "So far, we've seen summary-level information about our training steps\n", "in the table from Lightning and in the TensorBoard Overview.\n", "These give aggregate statistics about the computations that occurred,\n", "but understanding how to interpret those statistics\n", "and use them to speed up our networks\n", "requires understanding just what is\n", "happening in our training step.\n", "\n", "Fundamentally,\n", "all computations are processes that unfold in time.\n", "\n", "If we want to really understand our training step,\n", "we need to display it that way:\n", "what operations were occurring,\n", "on both the CPU and GPU,\n", "at each moment in time during the training step.\n", "\n", "This information on timing is collected in the trace.\n", "One of the best tools for viewing the trace over time\n", "is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)." ] }, { "cell_type": "markdown", "metadata": { "id": "wUkZItxYc20A" }, "source": [ "Let's tour the trace we just logged\n", "with an aim to really understanding just\n", "what is happening when we call\n", "`training_step`\n", "and by extension `.forward`, `.backward`, and `optimizer.step`." 
] }, { "cell_type": "markdown", "metadata": { "id": "9w9F2UA7Qctg" }, "source": [ "The Chrome Trace Viewer is built into W&B,\n", "so we can view our traces in their interface.\n", "\n", "The cell below embeds the trace inside the notebook,\n", "but you may wish to open it separately,\n", "with the \"Open page\" button or by navigating to the URL,\n", "so that you can interact with it\n", "as you read the description below.\n", "Display directly on W&B is also a bit less temperamental\n", "than display on W&B inside a notebook.\n", "\n", "Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n", "so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n", "It also can interact poorly with browser extensions (e.g. ad blockers),\n", "so you may need to deactivate them temporarily in order to see it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OMUs4aby6Rfd" }, "outputs": [], "source": [ "trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n", "trace_url = trace_files_url + \"training_step.pt.trace.json\"\n", "\n", "example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n", "\n", "print(trace_url)\n", "IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qNVpGeQtQjMG" }, "source": [ "> **Heads up!** We're about to do a tour of the\n", "> precise details of the tracing information logged\n", "> during the execution of the training code.\n", "> The only way to learn how to troubleshoot model performance\n", "> empirically is to look at the details,\n", "> but the details depend on the precise machine being used\n", "> -- GPU and CPU and RAM.\n", "> That means even within Colab,\n", "> these details change from session to session.\n", "> So if you don't 
observe a phenomenon or feature\n", "> described in the tour below, check out\n", "> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n", "> on W&B while reading through the next section of the lab,\n", "> and return to your trace once you understand the trace viewer better at the end.\n", "> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n", "> can sometimes be a bit janky." ] }, { "cell_type": "markdown", "metadata": { "id": "kXMcBhnCgdN_" }, "source": [ "This trace reveals, in nanosecond-level detail,\n", "what's going on inside of a `training_step`\n", "on both the GPU and the CPU.\n", "\n", "Time is on the horizontal axis.\n", "Colored bars represent method calls,\n", "and the methods called by a method are placed underneath it vertically,\n", "a visualization known as an\n", "[icicle chart](https://www.brendangregg.com/flamegraphs.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "67BsNzDfVIeg" }, "source": [ "Let's orient ourselves with some gross features:\n", "the forwards pass,\n", "GPU kernel execution,\n", "the backwards pass,\n", "and the optimizer step." 
] }, { "cell_type": "markdown", "metadata": { "id": "IBEFgtRCKqrh" }, "source": [ "### The forwards pass" ] }, { "cell_type": "markdown", "metadata": { "id": "5nYhiWesVMjK" }, "source": [ "Type in `resnet` to the search bar in the top-right.\n", "\n", "This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n", "\n", "It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n", "\n", "You can click the arrows next to that tile to partially collapse these blocks.\n", "\n", "Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n", "It should be at roughly the same height.\n", "\n", "Clear the search bar so that the trace is in color.\n", "Zoom in on the area of the forwards pass\n", "using the \"zoom\" tool in the floating toolbar,\n", "so you can see more detail.\n", "The zoom tool is indicated by a two-headed arrow\n", "pointing into and out of the screen.\n", "\n", "Switch to the \"drag\" tool,\n", "represented by a four-headed arrow.\n", "Click-and-hold to use this tool to focus\n", "on different parts of the timeline\n", "and click on the individual colored boxes\n", "to see details about a particular method call.\n", "\n", "As we go down in the icicle chart,\n", "we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n", "to much more precise `cudnn` and `cuda` operations\n", "(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n", "\n", "`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n", "is the tensor math library in PyTorch\n", "that links to specific backends like `cudnn`." 
] }, { "cell_type": "markdown", "metadata": { "id": "Fq181ybIvLsH" }, "source": [ "### GPU kernel execution" ] }, { "cell_type": "markdown", "metadata": { "id": "IbkWp5aKvLsH" }, "source": [ "Towards the bottom, you should see a section labeled \"GPU\".\n", "The label appears on the far left.\n", "\n", "Within it, you'll see one or more \"`stream`s\".\n", "These are units of work on a GPU,\n", "akin loosely to threads on the CPU.\n", "\n", "When there are colored bars in this area,\n", "the GPU is doing work of some kind.\n", "The fraction of this bar that is filled in with color\n", "is the same as the \"GPU Utilization %\" we've seen previously.\n", "So the first thing to visually assess\n", "in a trace view of PyTorch code\n", "is what fraction of this area is filled with color.\n", "\n", "In CUDA, work is queued up to be\n", "placed into streams and completed, on the GPU,\n", "in a distributed and asynchronous manner.\n", "\n", "The selection of which work to do\n", "is happening on the CPU,\n", "and that's what we were looking at above.\n", "\n", "The CPU and the GPU have to work together to coordinate\n", "this work.\n", "\n", "Type `cuda` into the search bar and you'll see these coordination operations happening:\n", "`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n", "\n", "Running the same PyTorch model\n", "with the same high level operations like `Conv2d` in different versions of PyTorch,\n", "on different GPUs, and even on tensors of different sizes will result\n", "in different choices of concrete kernel operation,\n", "e.g. different matrix multiplication algorithms.\n", "\n", "Type `sync` into the search bar and you'll see places where either work on the GPU\n", "or work on the CPU needs to await synchronization,\n", "e.g. 
copying data from the CPU to the GPU\n", "or the CPU waiting to decide what to do next\n", "on the basis of the contents of a tensor.\n", "\n", "If you see a \"sync\" block above an area\n", "where the stream on the GPU is empty,\n", "you've got a performance bottleneck due to synchronization\n", "between the CPU and GPU.\n", "\n", "To resolve the bottleneck,\n", "head up the icicle chart until you reach the recognizable\n", "PyTorch modules and operations.\n", "Find where they are called in your PyTorch module.\n", "That's a good place to review your code to understand why the synchronization is happening\n", "and to remove it if it's not necessary." ] }, { "cell_type": "markdown", "metadata": { "id": "XeMPbu_jvLsI" }, "source": [ "### The backwards pass\n", "\n", "Type `backward` into the search bar.\n", "\n", "This will highlight components of our backwards pass.\n", "\n", "If you read it from left to right,\n", "you'll see that it begins by calculating the loss\n", "(search for `NllLoss2DBackward` if you can't find it)\n", "and ends by doing a `ConvolutionBackward`\n", "for the first layer of the ResNet.\n", "It is, indeed, backwards.\n", "\n", "Like the forwards pass,\n", "the backwards pass also involves the CPU\n", "telling the GPU which kernels to run.\n", "It's typically run in a separate\n", "thread from the forwards pass,\n", "so you'll see it separated out from the forwards pass\n", "in the trace viewer.\n", "\n", "Generally, there's no need to specifically optimize the backwards pass --\n", "removing bottlenecks in the forwards pass results in a fast backwards pass.\n", "\n", "One reason why is that these two passes are just\n", "\"transposes\" of one another,\n", "so they share a lot of properties,\n", "and bottlenecks in one become bottlenecks in the other.\n", "We can choose to optimize either one of the two.\n", "But the forwards pass is under our direct control,\n", "so it's easier for us to reason about.\n", "\n", "Another reason is that 
the forwards pass is more likely to have bottlenecks.\n", "The forwards pass is a dynamic process,\n", "with each line of Python adding more to the compute graph.\n", "Backwards passes, on the other hand, use a static compute graph,\n", "the one just defined by the forwards pass,\n", "so more optimizations are possible." ] }, { "cell_type": "markdown", "metadata": { "id": "gWiDw0vCvLsI" }, "source": [ "### The optimizer step" ] }, { "cell_type": "markdown", "metadata": { "id": "ndfkzEdnvLsI" }, "source": [ "Type `Adam.step` into the search bar to highlight the computations of the optimizer.\n", "\n", "As with the two passes,\n", "we are still using the CPU\n", "to launch kernels on the GPU.\n", "But now the CPU is looping,\n", "in Python, over the parameters\n", "and applying the Adam update rule to each.\n", "\n", "We now know enough to see that\n", "this is not great for our GPU utilization:\n", "there are many areas of gray\n", "in between the colored bars\n", "in the GPU stream in this area.\n", "\n", "In the time it takes CUDA to multiply\n", "thousands of numbers,\n", "Python has not yet finished cleaning up\n", "after its request for that multiplication.\n", "\n", "As of writing in August 2022,\n", "more efficient optimizers are not a stable part of PyTorch (v1.12), but\n", "[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n", "and stable implementations outside of PyTorch.\n", "The standard implementations are\n", "[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n", "not to be confused with the\n", "[Apex Optimizers Project](https://www.apexoptimizers.com/),\n", "which is a collection of fitness-themed cheetah NFTs."
] }, { "cell_type": "markdown", "metadata": { "id": "WX0jxeafvLsI" }, "source": [ "## Take-aways for PyTorch performance bottleneck troubleshooting" ] }, { "cell_type": "markdown", "metadata": { "id": "CugD-bK2vLsI" }, "source": [ "Our goal here was to learn some basic principles and tools for finding and fixing\n", "the most common issues and the lowest-hanging fruit in PyTorch code." ] }, { "cell_type": "markdown", "metadata": { "id": "SwHwJkVMHYGA" }, "source": [ "\n", "Here's an overview in terms of a \"host\",\n", "generally the CPU,\n", "and a \"device\", here the GPU.\n", "\n", "- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n", "- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n", "  - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n", "- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n", "stores information about the contents of tensors,\n", "not just their metadata." ] }, { "cell_type": "markdown", "metadata": { "id": "Gntx28p9cBP5" }, "source": [ "Towards that goal, we viewed the trace to get an understanding of\n", "what's going on inside a PyTorch training step."
] }, { "cell_type": "markdown", "metadata": { "id": "AKvZGPnkeXvq" }, "source": [ "Here's what that means in terms of troubleshooting bottlenecks.\n", "\n", "We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n", "before the previous kernel finishes.\n", "\n", "Ideally, the CPU is actually getting far _ahead_ of execution\n", "on the GPU.\n", "If the CPU makes it all the way through the backwards pass before the GPU is done,\n", "that's great!\n", "The GPU(s) are the expensive part,\n", "and it's easy to use multiprocessing so that\n", "the CPU has other things to do.\n", "\n", "This helps explain at least one common piece of advice:\n", "the larger our batches are,\n", "the more work the GPU has to do for the same work done by the CPU,\n", "and so the better our utilization will be." ] }, { "cell_type": "markdown", "metadata": { "id": "XMztpa-TccH4" }, "source": [ "We operationalize our desire to never be waiting on the CPU with a simple metric:\n", "**100% GPU utilization**, meaning a kernel is running at all times.\n", "\n", "This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n", "\n", "You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will cost more than 2x as much." ] }, { "cell_type": "markdown", "metadata": { "id": "7kYBygfScR6z" }, "source": [ "Here are some of the most common issues that lead to low GPU Utilization, and how to resolve them:\n", "1. **The CPU is too weak**.\n", "Because so much of the discussion around DNN performance is about GPUs,\n", "it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n", "_Resolution_:\n", "Use nice CPUs, like\n", "[Threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n", "2.
**Too much Python during the `training_step`**.\n", "Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on a `__init__`\n", "that takes longer than running an entire layer.\n", "_Resolution_:\n", "Look for low utilization areas of the trace\n", "and check what's happening on the CPU at that time\n", "and carefully review the Python code being executed.\n", "3. **Unnecessary Host/Device synchronization**.\n", "If one of your operations depends on the values in a tensor,\n", "like `if xs.mean() >= 0`,\n", "you'll induce a synchronization between\n", "the host and the device and possibly lead\n", "to an expensive and slow copy of data.\n", "_Resolution_:\n", "Replace these operations as much as possible\n", "with purely array-based calculations.\n", "4. **Bottlenecking on the DataLoader**.\n", "In addition to coordinating the work on the GPU,\n", "CPUs often perform heavy data operations,\n", "including communication over the network\n", "and writing to/reading from disk.\n", "These are generally done in parallel to the forwards\n", "and backwards passes,\n", "but if they don't finish before that happens,\n", "they will become the bottleneck.\n", "_Resolution_:\n", "Get better hardware for compute,\n", "memory, and network.\n", "For software solutions, the answer \n", "is a bit more complex and application-dependent.\n", "For generic tips, see\n", "[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n", "in the PyTorch forums.\n", "For techniques in computer vision, see\n", "[the FFCV library](https://github.com/libffcv/ffcv)\n", "and for techniques in NLP, see e.g.\n", "[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n", "and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)." 
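The 2x bound quoted above for a GPU at 50% utilization follows from an Amdahl's-law-style argument. Here is a minimal sketch of that arithmetic, assuming (as a simplification) that a faster GPU only shrinks the time during which a kernel is running and leaves the idle gaps untouched. The function name and its parameters are illustrative and not part of any library.

```python
def speedup_from_faster_gpu(utilization: float, gpu_speedup: float) -> float:
    """Estimate end-to-end speedup when only GPU-busy time gets faster.

    Assumes step time splits into a GPU-busy fraction (`utilization`)
    and an idle fraction that a faster GPU does not change.
    """
    if not 0.0 <= utilization <= 1.0:
        raise ValueError("utilization must be between 0 and 1")
    if gpu_speedup <= 0:
        raise ValueError("gpu_speedup must be positive")
    busy = utilization / gpu_speedup  # GPU-busy time shrinks with a faster GPU
    idle = 1.0 - utilization          # time spent waiting on the host is unchanged
    return 1.0 / (busy + idle)


for factor in (2, 10, 1000):
    estimate = speedup_from_faster_gpu(0.5, factor)
    print(f"{factor}x faster GPU at 50% utilization: {estimate:.2f}x end-to-end")
```

Under this model the end-to-end speedup approaches, but never reaches, `1 / (1 - utilization)`, which is why fixing utilization comes before buying hardware.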
] }, { "cell_type": "markdown", "metadata": { "id": "i2WYS8bQvLsJ" }, "source": [ "### Further steps in making DNNs go brrrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "T0wW2_lRKfY1" }, "source": [ "It's important to note that utilization\n", "is just an easily measured metric\n", "that can reveal common bottlenecks.\n", "Having high utilization does not automatically mean\n", "that your performance is fully optimized.\n", "\n", "For example,\n", "synchronization events between GPUs\n", "are counted as kernels,\n", "so a deadlock during distributed training\n", "can show up as 100% utilization,\n", "despite literally no useful work occurring.\n", "\n", "Just switching to \n", "double precision floats, `--precision=64`,\n", "will generally lead to much higher utilization.\n", "The GPU operations take longer\n", "for roughly the same amount of CPU effort,\n", "but the added precision brings no benefit.\n", "\n", "In particular, it doesn't make for models\n", "that perform better on our correctness metrics,\n", "like loss and accuracy.\n", "\n", "Another useful yardstick to add\n", "to utilization is examples per second,\n", "which incorporates how quickly the model is processing data examples\n", "and calculating gradients.\n", "\n", "But really,\n", "the gold star is _decrease in loss per second_.\n", "This metric connects model design choices\n", "and hyperparameters with purely engineering concerns,\n", "so it disrespects abstraction barriers\n", "and doesn't generally lead to actionable recommendations,\n", "but it is, in the end, the real goal:\n", "make the loss go down faster so we get better models sooner." ] }, { "cell_type": "markdown", "metadata": { "id": "EFzPsplfdo_o" }, "source": [ "For PyTorch internals abstractly,\n", "see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "For more on performance considerations in PyTorch,\n", "see [Horace He's blog post](https://horace.io/brrr_intro.html)." 
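As a rough illustration of the examples-per-second yardstick, here is a minimal sketch using only the standard library. `ThroughputMeter` and `fake_training_step` are hypothetical helpers written for this example, and the sleep stands in for real work. In real GPU code, kernel launches are asynchronous, so you would need to synchronize with the device before reading the clock.

```python
import time


class ThroughputMeter:
    """Accumulates example counts and wall-clock time across steps."""

    def __init__(self) -> None:
        self.examples = 0
        self.seconds = 0.0

    def update(self, batch_size: int, elapsed: float) -> None:
        self.examples += batch_size
        self.seconds += elapsed

    @property
    def examples_per_second(self) -> float:
        # Guard against division by zero before any step has been recorded.
        return self.examples / self.seconds if self.seconds > 0 else 0.0


def fake_training_step(batch_size: int) -> None:
    """Stand-in for a real training step; sleeps instead of computing."""
    time.sleep(0.001 * batch_size)


meter = ThroughputMeter()
for _ in range(3):
    start = time.perf_counter()
    fake_training_step(batch_size=8)
    meter.update(batch_size=8, elapsed=time.perf_counter() - start)

print(f"{meter.examples_per_second:.0f} examples/s")
```

Dividing a drop in loss by `meter.seconds` instead of dividing the example count would give the loss-decrease-per-second figure discussed above.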
] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "yq6-S6TC38AY" }, "source": [ "### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n", "\n", "One of the most important features for making\n", "PyTorch run quickly is multiprocessing in the\n", "`DataLoader`,\n", "which executes batching of data in separate worker processes\n", "from the forwards and backwards passes.\n", "\n", "By default in PyTorch,\n", "this feature is actually turned off,\n", "via the `DataLoader` argument `num_workers`\n", "having a default value of `0`,\n", "but we set the `DEFAULT_NUM_WORKERS`\n", "to a value based on the number of CPUs\n", "available on the system running the code.\n", "\n", "Re-run the profiling cell,\n", "but set `num_workers` to `0`\n", "to turn off multiprocessing.\n", "\n", "Compare and contrast the two traces,\n", "both for total runtime\n", "(see the time axis at the top of the trace)\n", "and for utilization.\n", "\n", "If you're unable to run the profiles,\n", "see the results\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n", "which juxtaposes two traces,\n", "with in-process dataloading on the left and\n", "multiprocessing dataloading on the right." ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test."
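To build intuition for the `num_workers` comparison in the first exercise, here is a toy sketch of why moving data loading out of the training loop helps. It is not PyTorch code: a background thread and `time.sleep` calls stand in for worker processes, batch preparation, and the forwards/backwards passes, and all names and timings are made up for illustration.

```python
import queue
import threading
import time

LOAD_SECONDS = 0.02     # pretend cost of preparing one batch on the CPU
COMPUTE_SECONDS = 0.02  # pretend cost of one forwards/backwards pass
NUM_BATCHES = 5


def load_batch(index: int) -> int:
    time.sleep(LOAD_SECONDS)
    return index


def train_on(batch: int) -> None:
    time.sleep(COMPUTE_SECONDS)


def run_in_process() -> float:
    """Load and train in the same thread, loosely like `num_workers=0`."""
    start = time.perf_counter()
    for i in range(NUM_BATCHES):
        train_on(load_batch(i))
    return time.perf_counter() - start


def run_with_background_loader() -> float:
    """Load in a background thread so loading overlaps with training."""
    batches: queue.Queue = queue.Queue(maxsize=2)  # small prefetch buffer

    def producer() -> None:
        for i in range(NUM_BATCHES):
            batches.put(load_batch(i))
        batches.put(None)  # sentinel: no more data

    start = time.perf_counter()
    threading.Thread(target=producer, daemon=True).start()
    while True:
        batch = batches.get()
        if batch is None:
            break
        train_on(batch)
    return time.perf_counter() - start


serial_time = run_in_process()
overlapped_time = run_with_background_loader()
print(f"in-process: {serial_time:.3f}s, overlapped: {overlapped_time:.3f}s")
```

In the overlapped version, only the first batch's loading time is exposed. After that, each batch is prepared while the previous one is being trained on, which is the pattern to look for when comparing the two traces.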
] }, { "cell_type": "markdown", "metadata": { "id": "T2i_a5eVeIoA" }, "source": [ "The file below incorrectly implements and then incorrectly tests\n", "a simple PyTorch utility for adding five to every entry of a tensor\n", "and then calculating the sum.\n", "\n", "Even worse, it does it with horrible style!\n", "\n", "The cells below apply our linting checks\n", "(after automatically fixing the formatting)\n", "and run the test.\n", "\n", "Fix all of the lints,\n", "implement the function correctly,\n", "and then implement some basic tests." ] }, { "cell_type": "markdown", "metadata": { "id": "wSon2fB5VVM_" }, "source": [ "- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aYiRvU4HA84t" }, "outputs": [], "source": [ "%%writefile training/fixme.py\n", "import torch\n", "from training import run_experiment\n", "from numpy import *\n", "import random\n", "from pathlib import Path\n", "\n", "\n", "\n", "\n", "def add_five_and_sum(tensor):\n", " # this function is not implemented right,\n", " # but it's supposed to add five to all tensor entries and sum them up\n", " return 1\n", "\n", "def test_add_five_and_sum():\n", " # and this test isn't right either! 
plus this isn't exactly a docstring\n", " all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n", " all_fives = 5 * all_ones\n", " assert False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EXJpmvuzT1w0" }, "outputs": [], "source": [ "!pre-commit run black --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SRO-oJfdUrcQ" }, "outputs": [], "source": [ "!cat training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jM8NHxVbSEQD" }, "outputs": [], "source": [ "!pre-commit run --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kj0VMBSndtkc" }, "outputs": [], "source": [ "!pytest training/fixme.py" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab05_troubleshooting.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab07/notebooks/lab06_data.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 06: Data Annotation" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How the `IAM` handwriting dataset is structured on disk and how it is processed into an ML-friendly format\n", "- How to set up a [Label Studio](https://labelstud.io/) data annotation server\n",
"- Just how messy data really is" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 6\n", "\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "DpvaHz9TEGwV" }, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gsXpeXi2EGwV" }, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "IFrame(src=\"https://fsdl.me/2022-lab-06-video-embed\", width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "XTkKzEMNR8XZ" }, "source": [ "# `IAMParagraphs`: From annotated data to a PyTorch `Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "3mQLbjuiwZuj" }, "source": [ "We've used the `text_recognizer.data` submodule\n", "and its `LightningDataModule`s -- `IAMLines` and `IAMParagraphs`\n", "for lines and paragraphs of handwritten 
text\n", "from the\n", "[IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database).\n", "\n", "These classes convert data from a database-friendly format\n", "designed for storage and transfer into the\n", "format our DNNs expect:\n", "PyTorch `Tensor`s.\n", "\n", "In this section,\n", "we'll walk through that process in detail.\n", "\n", "In the following section,\n", "we'll see how data\n", "goes from signals measured in the world\n", "to the format we consume here." ] }, { "cell_type": "markdown", "metadata": { "id": "499c23a6" }, "source": [ "## Dataset structure on disk" ] }, { "cell_type": "markdown", "metadata": { "id": "a3438d2e" }, "source": [ "We begin by downloading the raw data to disk." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "18900eec" }, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM\n", "\n", "iam = IAM()\n", "iam.prepare_data()" ] }, { "cell_type": "markdown", "metadata": { "id": "a332f359" }, "source": [ "The `IAM` dataset is downloaded as a zip file\n", "and then unzipped:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d6c44266" }, "outputs": [], "source": [ "from text_recognizer.metadata.iam import DL_DATA_DIRNAME\n", "\n", "\n", "iam_dir = DL_DATA_DIRNAME\n", "!ls {iam_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "8463c2d1" }, "source": [ "The unzipped dataset is not simply a flat directory of files.\n", "\n", "Instead, there are a number of subfolders,\n", "each of which contains a particular type of data or metadata."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "536924f7" }, "outputs": [], "source": [ "iamdb = iam_dir / \"iamdb\"\n", "\n", "!du -h {iamdb}" ] }, { "cell_type": "markdown", "metadata": { "id": "b745a594" }, "source": [ "For example, the `task` folder contains metadata about canonical dataset splits:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "84c21f75" }, "outputs": [], "source": [ "!find {iamdb / \"task\"} | grep \"\\\\.txt$\"" ] }, { "cell_type": "markdown", "metadata": { "id": "mEb0Pdm4vIHe" }, "source": [ "We find the images of handwritten text in the `forms` folder.\n", "\n", "An individual \"datapoint\" in `IAM` is a \"form\",\n", "because the humans whose hands wrote the text were prompted to write on \"forms\",\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "945d5e3a" }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "\n", "form_fn, = !find {iamdb}/forms | grep \".jpg$\" | sort | head -n 1\n", "\n", "print(form_fn)\n", "Image(filename=form_fn, width=\"360\")" ] }, { "cell_type": "markdown", "metadata": { "id": "b9e9e384" }, "source": [ "Meanwhile, the `xml` files contain the data annotations,\n", "written out as structured text:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6add5c5a" }, "outputs": [], "source": [ "xml_fn, = !find {iamdb}/xml | grep \"\\.xml$\" | sort | head -n 1\n", "\n", "!cat {xml_fn} | grep -A 100 \"handwritten-part\" | grep \"" ] }, { "cell_type": "markdown", "metadata": { "id": "MX9n-Zed8G_T" }, "source": [ "# Lab 07: Deployment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What You Will Learn\n", "\n", "- How to convert PyTorch models into portable TorchScript binaries\n", "- How to use `gradio` to make a simple demo UI for your ML-powered applications\n", "- How to split out a model service from the frontend and spin up a publicly accessible application" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "45D6GuSwvT7d" }, "outputs": [], "source": [ "lab_idx = 7\n", "\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pzi8qYKI-njP" }, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "IFrame(src=\"https://fsdl.me/2022-lab-07-video-embed\", width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "SAw7BEI_sCZZ" }, "source": [ "# Making the model portable" ] }, { "cell_type": "markdown", "metadata": { "id": "8zL0K2Xe-MWJ" }, "source": [ "While training the model,\n", "we've saved checkpoints and stored them locally\n", "and on W&B.\n", "\n", "From these checkpoints, we can reload model weights\n", "and even restart training if we 
are in or can recreate\n", "the model development environment.\n", "\n", "We could directly deploy these checkpoints into production,\n", "but they're suboptimal for two reasons.\n", "\n", "First, as the name suggests,\n", "these \"checkpoints\" are designed for serializing\n", "state at a point of time in training.\n", "\n", "That means they can include lots of information\n", "not relevant during inference,\n", "e.g. optimizer states like running average gradients.\n", "\n", "Additionally, the model development environment\n", "is much more heavyweight than what we need during inference.\n", "\n", "For example, we've got Lightning for training models\n", "and W&B for tracking training runs.\n", "\n", "These in turn incur dependencies on lots of heavy data science libraries.\n", "\n", "We don't need this anymore -- we just want to run the model.\n", "\n", "These are effectively \"compiler tools\", which our runtime model doesn't need.\n", "\n", "So we need a new model binary artifact for runtime\n", "that's leaner and more independent.\n", "\n", "For this purpose, we use TorchScript." ] }, { "cell_type": "markdown", "metadata": { "id": "0bMPqKDjs623" }, "source": [ "## Compiling models to TorchScript" ] }, { "cell_type": "markdown", "metadata": { "id": "7d9EmZ0j_AQF" }, "source": [ "Torch has two main facilities for creating\n", "more portable model binaries:\n", "_scripting_ and _tracing_." 
] }, { "cell_type": "markdown", "metadata": { "id": "h9PVzwjQ_YHg" }, "source": [ "Scripting produces a binary that combines\n", "constant `Tensor` values\n", "(like weights and positional embeddings)\n", "with a program that describes how to use them.\n", "\n", "The result is a program that creates a dynamic graph,\n", "as does a normal PyTorch program,\n", "but this program is written in a\n", "sub-dialect of Python called\n", "_TorchScript_.\n", "\n", "The [TorchScript sub-dialect of Python](https://pytorch.org/docs/stable/jit_language_reference.html#language-reference)\n", "is more performant\n", "and can even be run without a Python interpreter.\n", "\n", "For example, TorchScript programs can be executed in pure C++\n", "[using LibTorch](https://pytorch.org/tutorials/advanced/cpp_export.html).\n", "\n", "You can read more in the documentation for the primary method\n", "for scripting models, `torch.jit.script`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "h1VtGt_Xj_H7" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "torch.jit.script??" ] }, { "cell_type": "markdown", "metadata": { "id": "tUOm7G9ESi4s" }, "source": [ "The primary alternative to scripting is _tracing_,\n", "which runs the PyTorch module on a specific\n", "set of inputs and records, or \"traces\",\n", "the compute graph.\n", "\n", "You can read more about it in the documentation for the primary method\n", "for tracing models, `torch.jit.trace`,\n", "or just read the quick summary and comparison below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pn3QLOFNjuOa" }, "outputs": [], "source": [ "torch.jit.trace??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tracing versus Scripting for TorchScript" ] }, { "cell_type": "markdown", "metadata": { "id": "uP4TfihfBw9z" }, "source": [ "The traced program is generally faster than the scripted version,\n", "for models that are compatible with both tracing and scripting.\n", "\n", "Tracing produces a static compute graph,\n", "which means all control flow\n", "(`if`s or `for` loops)\n", "is effectively inlined.\n", "\n", "As written, our text recognizer has a loop with conditional breaking -- fairly typical for Transformers in autoregressive mode --\n", "so it isn't compatible with tracing.\n", "\n", "Furthermore, the static compute graph includes concrete choices of operations,\n", "e.g. specific CUDA kernels if tracing is run on the GPU.\n", "\n", "If you try to run the traced model on a system that doesn't support those kernels,\n", "it will crash.\n", "That means tracing must occur in the target deployment environment.\n", "\n", "Scripted models are much more portable, at the cost of both slower runtimes\n", "for a fixed hardware target and of some restrictions on how dynamic the Python code can be.\n", "\n", "We don't find the restrictions scripting places on Python code to be too onerous\n", "and in our experience, the performance gains are not worth the extra effort\n", "until the team size is larger,\n", "model serving hardware and strategy are more mature,\n", "and model release cycles are slower.\n", "\n", "For an alternative perspective that's more in favor of tracing\n", "and walks through how to mix-and-match scripting\n", "and tracing for maximum flexibility and performance, see\n", "[this blogpost](https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/)\n", "from\n", "[Detectron2](https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/)\n", "dev Yuxin Wu."
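To make the point about inlined control flow concrete, here is a toy tracer in plain Python. It is a sketch for intuition only and makes no claim about how `torch.jit.trace` is implemented: the `Traced`, `trace`, and `replay` names are invented for this example. Arithmetic on a `Traced` value is recorded on a tape, but comparisons return an ordinary `bool`, so whichever branch the example input takes is the only one the tape ever sees.

```python
class Traced:
    """A number that records every arithmetic operation applied to it."""

    def __init__(self, value, tape, name):
        self.value = value
        self.tape = tape
        self.name = name

    def _record(self, op, other):
        result = {"add": self.value + other, "mul": self.value * other}[op]
        name = f"t{len(self.tape)}"
        self.tape.append((name, op, self.name, other))
        return Traced(result, self.tape, name)

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    def __gt__(self, other):
        # Comparisons yield a plain bool, so the branch taken never reaches the tape.
        return self.value > other


def model(x):
    """A 'model' with data-dependent control flow."""
    if x > 0:
        return x * 2
    return x + 100


def trace(fn, example_input):
    """Run fn once on an example input and return the recorded operations."""
    tape = []
    fn(Traced(example_input, tape, "x"))
    return tape


def replay(tape, x):
    """Execute the recorded operations on a new input, with no branching."""
    env = {"x": x}
    ops = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}
    result = x
    for name, op, arg_name, constant in tape:
        result = env[name] = ops[op](env[arg_name], constant)
    return result


tape = trace(model, example_input=3)
print(tape)              # [('t0', 'mul', 'x', 2)]
print(replay(tape, 3))   # 6, matches model(3)
print(replay(tape, -1))  # -2, but model(-1) returns 99
```

The replayed tape agrees with the original function only for inputs that take the same branch as the example, which is the kind of silent mismatch that makes data-dependent loops and conditionals a poor fit for tracing.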
] }, { "cell_type": "markdown", "metadata": { "id": "cDARv-GdqtET" }, "source": [ "Choosing just one of scripting or tracing\n", "means we can use a high-level method\n", "from PyTorch Lightning,\n", "`to_torchscript`,\n", "to produce our scripted model binary\n", "and we don't need to touch our model code." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "udvnx7sBBklY" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "\n", "pl.LightningModule.to_torchscript??" ] }, { "cell_type": "markdown", "metadata": { "id": "iXftpJBizrM6" }, "source": [ "## Alternatives to TorchScript" ] }, { "cell_type": "markdown", "metadata": { "id": "QvFh_SW8v4p6" }, "source": [ "Though it has some sharp edges,\n", "TorchScript is a relatively easy to use tool\n", "for compiling neural networks written in PyTorch.\n", "\n", "If you're willing to tolerate more sharp edges,\n", "e.g. limited support for certain ops\n", "and a higher risk of subtle differences in behavior, the\n", "[Open Neural Network eXchange](https://onnx.ai/)\n", "format, ONNX, is a compilation target for\n", "[a wide variety of DNN libraries](https://onnx.ai/supported-tools.html),\n", "from `sklearn` and MATLAB\n", "to PyTorch and Hugging Face.\n", "\n", "A high-level utility for conversion to ONNX is also included\n", "in PyTorch Lightning, `pl.LightningModule.to_onnx`.\n", "\n", "Because it is framework agnostic,\n", "there's more and more varied tooling around ONNX,\n", "and it has smoother paths to\n", "compilation targets that can run DNNs\n", "at the highest possible speeds,\n", "like\n", "[NVIDIA's TensorRT](https://developer.nvidia.com/tensorrt)\n", "or\n", "[Apache TVM](https://tvm.apache.org/2017/08/17/tvm-release-announcement).\n", "\n", "TensorRT is the model format used in the\n", "[Triton Inference Server](https://github.com/triton-inference-server/server),\n", "a sort of \"kubernetes for GPU-accelerated DNNs\"\n", "that is, as of 2022,\n", "the state of the 
art in running deep networks\n", "at maximum throughput on server-grade GPUs.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "36dKPerevkhZ" }, "source": [ "## A simple script for compiling and staging models" ] }, { "cell_type": "markdown", "metadata": { "id": "93pc-NLrBR1A" }, "source": [ "To recap, our model staging workflow,\n", "which does the hand-off between training and production, looks like this:\n", "\n", "1. Get model weights and hyperparameters\n", "from a tracked training run in W&B's cloud storage.\n", "2. Reload the model as a `LightningModule` using those weights and hyperparameters.\n", "3. Call `to_torchscript` on it.\n", "4. Save that result to W&B's cloud storage.\n", "\n", "We provide a simple script to implement this process:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gqgiWO0tFktU" }, "outputs": [], "source": [ "%run training/stage_model.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "i4qEqMRkFsd4" }, "source": [ "Here in this notebook,\n", "rather than training or scripting a model ourselves,\n", "we'll just `--fetch`\n", "an already trained and scripted model binary:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c2wfjLmRDwrH" }, "outputs": [], "source": [ "%run training/stage_model.py --fetch --entity=cfrye59 --from_project=fsdl-text-recognizer-2021-training" ] }, { "cell_type": "markdown", "metadata": { "id": "I0uNnvjkCZzX" }, "source": [ "Note that we can use the metadata of the staged model\n", "to find the training run that generated the model weights.\n", "It requires two graph hops:\n", "find the run that created the staged TorchScript model\n", "then in that run,\n", "find the model checkpoint artifact\n", "and look for the run that created it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E9zJg44hCjRv" }, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "staged_model_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/artifacts/prod-ready/paragraph-text-recognizer/3e07efa34aec61999c5a/overview\"\n", "\n", "IFrame(staged_model_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we're deploying our first model,\n", "this doesn't feel that important --\n", "it's easy enough to find the training runs\n", "we've executed and connect them to the model in production.\n", "\n", "But as we train and release more models,\n", "this information will become harder to find\n", "and automation and API access will become more important.\n", "\n", "This will be especially true if we adopt more sophisticated rollout strategies,\n", "like A/B testing or canarying,\n", "as the application matures.\n", "\n", "Our system here is not robust enough to be Enterprise Grade™️ --\n", "marking models as \"in production\" is manual\n", "and there are no access control planes built in --\n", "but at least the information is preserved."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running our more portable model via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "X7d2WHSCHHHP" }, "source": [ "Now that our TorchScript model binary file is present,\n", "we can spin up our text recognizer\n", "with much less code.\n", "\n", "We just need a compatible version of PyTorch\n", "and methods to convert\n", "our generic data types\n", "(images, strings)\n", "to and from PyTorch `Tensor`s.\n", "\n", "We can put all this together in\n", "a single light-weight object,\n", "the `ParagraphTextRecognizer` class:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZGXZep-nDiDk" }, "outputs": [], "source": [ "from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer\n", "\n", "\n", "ParagraphTextRecognizer??\n", "\n", "ptr = ParagraphTextRecognizer()" ] }, { "cell_type": "markdown", "metadata": { "id": "uwVo6BoeGmTW" }, "source": [ "And from there,\n", "we can start running on images\n", "and inferring the text that they contain:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CMZlfIoeG3hy" }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "example_input = \"text_recognizer/tests/support/paragraphs/a01-077.png\"\n", "\n", "print(ptr.predict(example_input))\n", "Image(example_input)" ] }, { "cell_type": "markdown", "metadata": { "id": "I6AHq1TH44Jq" }, "source": [ "As usual,\n", "we write our Python code\n", "so that it can be imported as a module\n", "and run in a Jupyter notebook,\n", "for documentation and experimentation,\n", "and we make it executable as a script\n", "for easier automation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "igY7sd8eGGI3" }, "outputs": [], "source": [ "%run text_recognizer/paragraph_text_recognizer.py --help\n", "\n", "%run text_recognizer/paragraph_text_recognizer.py {example_input}" ] }, { "cell_type": "markdown", "metadata": { 
"id": "MvYmSN0rE2BP" }, "source": [ "Notice that the `filename` here can be a local file, a URL, or even a cloud storage URI.\n", "\n", "Rather than writing the logic for handling these different cases,\n", "we use the\n", "[`smart_open` library](https://pypi.org/project/smart-open/)." ] }, { "cell_type": "markdown", "metadata": { "id": "3WQ-P16VC94R" }, "source": [ "## Testing our model development pipeline" ] }, { "cell_type": "markdown", "metadata": { "id": "0kVq2iBJDZH5" }, "source": [ "Creating models is _the_ critical function of our code base,\n", "so it's important that we test it,\n", "at the very least with \"smoke tests\" that let us know\n", "if the code is completely broken.\n", "\n", "Right now we have tests for data loading and model training,\n", "but no tests for end-to-end model development,\n", "which combines data loading, model training, and model compilation.\n", "\n", "So we add a simple model development test\n", "that trains a model for a very small number of steps\n", "and then runs our staging script.\n", "\n", "This model development test script returns an error code (`exit 1`) if the process of\n", "building a model fails (`\"$FAILURE\" = true`).\n", "\n", "We use\n", "[the `||` operator](https://www.unix.com/shell-programming-and-scripting/42417-what-does-mean-double-pipe.html)\n", "to set the `FAILURE` variable to `true` if any of the key commands in model development fail." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XPkwFxklDA5V", "scrolled": false }, "outputs": [], "source": [ "!cat training/tests/test_model_development.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "pQ21iRDqFvxj" }, "source": [ "As a next step to improve the coverage of this test,\n", "we might compare the model's outputs\n", "on the same inputs before and after compilation." 
] }, { "cell_type": "markdown", "metadata": { "id": "hyXZhgqEvfe9" }, "source": [ "### Cleaning up artifacts" ] }, { "cell_type": "markdown", "metadata": { "id": "l22DqhC4GIJT" }, "source": [ "The final few lines of the testing script mention\n", "\"`selecting for deletion`\" some artifacts." ] }, { "cell_type": "markdown", "metadata": { "id": "EbIW5okFGQv7" }, "source": [ "As we incorporate more of our code into testing\n", "and develop more models,\n", "the amount of information we are storing on W&B increases.\n", "\n", "We're already uploading model checkpoints, several gigabytes per model training run,\n", "and now we're also looking at uploading several hundred megabytes\n", "of model data per execution of our test." ] }, { "cell_type": "markdown", "metadata": { "id": "T7aBCfpuJVJV" }, "source": [ "Artifact storage is free up to 100GB,\n", "but storing more requires a paid account.\n", "\n", "That means it literally pays to clean up after ourselves.\n", "\n", "We use a very simple script to select certain artifacts for deletion.\n", " \n", "> ⚠️ **Don't use this untested demonstration script in important environments!** ⚠️\n", "We include options for `-v`erbose output and a `--dryrun` mode,\n", "which are both critical for destructive actions that have access\n", "to model weights that might cost $1000s to produce.\n", "\n", "See the `--help` below for more on cleaning up artifacts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8hSqzRplITVB" }, "outputs": [], "source": [ "%run training/cleanup_artifacts.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "BfB38ywTJDMT" }, "source": [ "## Tuning inference performance on CPU and GPU" ] }, { "cell_type": "markdown", "metadata": { "id": "zau0MRr1FPw-" }, "source": [ "Apart from compilation to TorchScript,\n", "the biggest difference for running the model in production\n", "is that now all of our operations occur on the CPU.\n", "\n", "This is a surprising feature of DNN deployment\n", "that's worth thinking about in detail.\n", "\n", "Why isn't it a given that deep network inference\n", "runs on GPUs, when that's so critical for deep network training?\n", "\n", "First,\n", "not many web applications use GPUs,\n", "so there aren't nearly as many good tools and techniques\n", "for deploying GPU-backed services.\n", "\n", "But there's another, deeper reason:\n", "GPUs are not as easy to run efficiently\n", "during inference as they are in training.\n", "\n", "In training,\n", "we use static or synthetic datasets\n", "and our training code is in charge\n", "of the query patterns.\n", "\n", "In particular,\n", "we can request exactly as many inputs\n", "as we want to produce a batch\n", "that makes optimal use\n", "of our expensive GPUs.\n", "\n", "In production, requests arrive independently,\n", "according to the whims of our users.\n", "\n", "This makes batching challenging,\n", "and by far the simplest service architecture\n", "just runs on each request as it arrives.\n", "\n", "But that tanks GPU utilization.\n", "\n", "GPUs are highly parallel computers,\n", "and batch is the easiest dimension to parallelize on --\n", "for example, we load the model weights into memory once,\n", "use them, and then release the memory.\n", "\n", "The cell below\n", "compares two traces\n", "for a GPU-accelerated\n", "Text Recognizer model running\n", "on a single input and on 
a batch.\n", "\n", "For a simple summary,\n", "you can compare the two profiles in TensorBoard\n", "([batch size 1 here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-labs-lab05_training/runs/1vj48h6j/tensorboard?workspace=user-cfrye59),\n", "[batch size 16 here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59)).\n", "\n", "GPU utilization,\n", "our baseline metric for model performance,\n", "is under 50% with batch size 1,\n", "as compared to >90% with batch size 16,\n", "which fills up GPU RAM.\n", "\n", "You can also look through the traces for more details:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B_NZPLWa-ZVP" }, "outputs": [], "source": [ "trace_comparison_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-labs-lab05_training/reports/Trace-Comparison-Batch-Size-16-vs-1--VmlldzoyNTg2MTU4\"\n", "\n", "print(trace_comparison_url)\n", "IFrame(src=trace_comparison_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "_6_U1OyU-Vsi" }, "source": [ "But performance during inference is not as simple \n", "as just \"maximize GPU utilization\".\n", "\n", "In particular, throughput for the GPU with batch size 16\n", "is over 2x better,\n", "one example per 8 ms vs\n", "one example per 40 ms,\n", "but latency is much worse.\n", "\n", "It takes 140ms to complete the batch of size 16.\n", "In the intervening time no examples are completed,\n", "and all 16 users are waiting on a response.\n", "\n", "For comparison,\n", "running one example at a time\n", "would get the first user's result\n", "in just 40 ms,\n", "but the total processing time for all 16 examples would be\n", "640 ms.\n", "\n", "For user experience, latency is critical,\n", "but for making the most efficient use of hardware,\n", "throughput is generally more important.\n", "\n", "During training, we care much less about latency\n", "and much more about computing 
gradients as fast as possible,\n", "so we aim for larger batch sizes.\n", "\n", "Because of the need for efficient use of hardware,\n", "running on single inputs isn't always feasible.\n", "\n", "The usual solution is to run a queue,\n", "which collects up enough requests for a batch.\n", "\n", "One of the easiest ways to do this as of writing in September 2022 is to use\n", "[`cog` by Replicate](https://github.com/replicate/cog),\n", "which both solves difficult issues with containerizing\n", "models with GPU acceleration \n", "and includes, as a beta feature, a built-in Redis queue\n", "for batching requests and responses.\n", "\n", "But note that we can't just run a queue that waits for,\n", "say, 16 user requests\n", "to build up, then runs them all.\n", "If 15 requests come in at once,\n", "but then no requests come for an hour,\n", "all 15 users will be waiting for an hour\n", "for their responses --\n", "much worse than just waiting a few hundred extra milliseconds!\n", "\n", "We need to make sure the queue flushes after a certain amount of time,\n", "regardless of how many requests it has received,\n", "complicating our implementation.\n", "\n", "Running single inputs on GPUs\n", "and running a naive queue\n", "are two different ways it's easy to accidentally tank latency\n", "while pursuing efficiency,\n", "at least for some fraction of cases.\n", "\n", "So we stick with CPU inference." 
] }, { "cell_type": "markdown", "metadata": { "id": "te-CYidTslPo" }, "source": [ "# Building a simple model UI" ] }, { "cell_type": "markdown", "metadata": { "id": "4kGXwQvjJq32" }, "source": [ "With compilation,\n", "we've moved from a model that can only run\n", "in a very special environment\n", "and with lots of support code\n", "into something lightweight\n", "that runs with a simple CLI.\n", "\n", "If we want users to send data to our model\n", "and get useful predictions out,\n", "we need to create a UI.\n", "\n", "But a CLI is not a UI --\n", "it's at best the foundation out of which a UI is built.\n", "\n", "This is not just a concern once the model is finished:\n", "a UI is an incredible tool for model debugging.\n", "\n", "It's hard to overstate the difference between\n", "a static, CLI or code-writing workflow\n", "for sending information to a model\n", "and an interactive interface.\n", "\n", "When your model is easily accessible on a mobile phone,\n", "when you can copy-paste text from elsewhere on your machine or the internet,\n", "or when you can upload arbitrary files,\n", "the whole range of possible inputs becomes clear\n", "in a way that's very hard to replicate with fixed data sets." ] }, { "cell_type": "markdown", "metadata": { "id": "S163btePLB1K" }, "source": [ "Unfortunately, creating a GUI from scratch is not easy,\n", "especially in Python.\n", "\n", "The best tool for GUIs is the browser,\n", "but the lingua franca of the browser\n", "is JavaScript\n", "([for now](https://webassembly.org/)).\n", "\n", "As full stack deep learning engineers,\n", "we're already writing Python with C/C++ acceleration,\n", "we're gluing scripts together with Bash,\n", "and we need to know enough SQL to talk to databases.\n", "\n", "Do we now need to learn front-end web development too?" 
] }, { "cell_type": "markdown", "metadata": { "id": "oSeBo0MzL0H9" }, "source": [ "In the long term, it's a good investment,\n", "and we recommend\n", "[The Odin Project](https://www.theodinproject.com/),\n", "a free online course and community for learning web development.\n", "\n", "Their\n", "[Foundations course](https://www.theodinproject.com/paths/foundations/courses/foundations#html-foundations),\n", "starting from HTML foundations and proceeding\n", "through basic CSS\n", "and JavaScript,\n", "is a great way to dip your toes in\n", "and learn enough about building websites and UIs\n", "in the browser to be dangerous." ] }, { "cell_type": "markdown", "metadata": { "id": "q-7pJcsCL_84" }, "source": [ "In the short term,\n", "we write our frontends in Python libraries\n", "that effectively write the frontend JavaScript/CSS/HTML\n", "for us.\n", "\n", "For the past few years,\n", "[Streamlit](https://streamlit.io/)\n", "has been a popular choice for the busy Python data scientist.\n", "\n", "It remains a solid choice,\n", "and tooling for building complex apps with Streamlit is more mature." 
] }, { "cell_type": "markdown", "metadata": { "id": "xey5gzr5tV51" }, "source": [ "We use the\n", "[`gradio` library](https://gradio.app/),\n", "which includes a simple API for wrapping\n", "a single Python function into a frontend\n", "in addition to a less mature, lower-level API\n", "for building apps more flexibly.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2XvUr7irMHQ6" }, "source": [ "This iteration of the FSDL codebase\n", "includes a new module,\n", "`app_gradio`,\n", "that makes a simple UI for the Text Recognizer\n", "using `gradio`.\n", "\n", "The core component is a script,\n", "`app_gradio/app.py`,\n", "that can be used to spin up our model and UI\n", "from the command line:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w2Ra8ot292XX" }, "outputs": [], "source": [ "%run app_gradio/app.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "J9bP3zFo9_YY" }, "source": [ "But one very nice feature of `gradio`\n", "is that it is designed to run as easily\n", "from the notebook as from the command line.\n", "\n", "Let's import the contents of `app.py`\n", "and take a look,\n", "then launch our UI." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vryi5r6gDj6D" }, "outputs": [], "source": [ "from app_gradio import app\n", "\n", "\n", "app.make_frontend??\n", "frontend = app.make_frontend(ptr.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use `gradio`'s high-level API, `gr.Interface`,\n", "to build a UI by wrapping our `ptr.predict` function,\n", "defining its inputs\n", "(an `Image`)\n", "and outputs\n", "(a `Textbox`),\n", "and specifying some formatting\n", "and styling choices." 
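] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The wrapping pattern itself fits in a few lines.\n", "\n", "The cell below is a minimal sketch, not the contents of `make_frontend`:\n", "`describe_image` is a hypothetical stand-in for `ptr.predict`\n", "that reports the size of the image rather than reading its text,\n", "and the `title` is made up for this example.\n", "\n", "It only constructs the interface;\n", "calling `sketch.launch()` would serve it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "\n", "\n", "def describe_image(image):\n", "    \"\"\"Hypothetical stand-in for a model: takes a PIL image, returns a string.\"\"\"\n", "    return f\"received a {image.width}x{image.height} image\"\n", "\n", "\n", "sketch = gr.Interface(\n", "    fn=describe_image,  # the Python function being wrapped\n", "    inputs=gr.Image(type=\"pil\"),  # uploaded images are passed to fn as PIL images\n", "    outputs=gr.Textbox(),  # the returned string is shown in a text box\n", "    title=\"Interface sketch\",  # one of the formatting choices\n", ")"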
] }, { "cell_type": "markdown", "metadata": { "id": "m0HxOukBNn13" }, "source": [ "\n", "\n", "We can spin up our UI with the `.launch` method,\n", "and now we can interact\n", "with the model from inside the notebook.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XoVFtGbuDlTL" }, "outputs": [], "source": [ "frontend.launch(share=True, width=\"100%\")" ] }, { "cell_type": "markdown", "metadata": { "id": "okcoAW7sM13h" }, "source": [ "For 72 hours, we can also access the model over the public internet\n", "using a URL provided by `gradio`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x5pEhMECNIT6" }, "outputs": [], "source": [ "print(frontend.share_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "LYfi-lZqNNZd" }, "source": [ "You can point your browser to that URL\n", "to see what the model looks like as a full-fledged web application,\n", "instead of a widget inside the notebook." ] }, { "cell_type": "markdown", "metadata": { "id": "2L5uZCJlOGi4" }, "source": [ "In addition to this UI,\n", "`gradio` also creates a simple REST API,\n", "so we can make requests\n", "from outside the browser,\n", "programmatically,\n", "and get responses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XOngmAWvQnqg" }, "outputs": [], "source": [ "%env API_URL={frontend.share_url + \"/api\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "cj6XSur7Nlzf" }, "source": [ "We can see the details of the API by clicking\n", "\"view api\" at the bottom of the Gradio interface.\n", "\n", "In particular,\n", "we can see that the API expects image data in\n", "[base64 format](https://developer.mozilla.org/en-US/docs/Glossary/Base64),\n", "which encodes binary data as ASCII text\n", "so that it can be sent over interfaces that expect ASCII text." 
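] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In Python, base64 encoding takes a couple of standard library calls.\n", "\n", "The cell below is a minimal sketch:\n", "the helper name `make_payload` is hypothetical,\n", "and it assumes the request body is a JSON object\n", "whose `data` list holds a single PNG data URI." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import base64\n", "import json\n", "from pathlib import Path\n", "\n", "\n", "def make_payload(image_path):\n", "    \"\"\"Hypothetical helper: packages a PNG file as a JSON string with a base64 data URI.\"\"\"\n", "    image_bytes = Path(image_path).read_bytes()  # raw binary contents of the file\n", "    encoded = base64.b64encode(image_bytes).decode(\"ascii\")  # binary -> ASCII text\n", "    return json.dumps({\"data\": [f\"data:image/png;base64,{encoded}\"]})\n", "\n", "\n", "payload = make_payload(\"text_recognizer/tests/support/paragraphs/a01-077.png\")\n", "print(payload[:64], \"...\")  # just the start of the payload, which is long"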
] }, { "cell_type": "markdown", "metadata": { "id": "igeFyT84WqqG" }, "source": [ "The line below encodes an image with the `base64` utility,\n", "packages it into the appropriate JSON format\n", "and uses `echo` to pipe it into a `curl` command.\n", "\n", "`curl` can be used to make requests to web services at URLs\n", "-- here `${API_URL}/predict` --\n", "of specific types\n", "-- here `POST` --\n", "that include `-d`ata\n", "and `-H`eaders identifying the format of the data.\n", "\n", "The response is returned as\n", "[string-formatted JSON](https://developer.mozilla.org/en-US/docs/Learn/JavaScript/Objects/JSON)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_nmRbYQCOd3t" }, "outputs": [], "source": [ "response, = ! \\\n", " (echo -n '{ \"data\": [\"data:image/png;base64,'$(base64 -w0 -i text_recognizer/tests/support/paragraphs/a01-077.png)'\"] }') \\\n", " | curl -s -X POST \"${API_URL}/predict\" -H 'Content-Type: application/json' -d @-\n", " \n", "response" ] }, { "cell_type": "markdown", "metadata": { "id": "tLy9z593X4_o" }, "source": [ "JSON, short for \"JavaScript Object Notation\",\n", "is effectively the standard for representing dictionaries\n", "when sharing information between applications\n", "that may be written in different languages.\n", "\n", "With the standard library's `json.loads`,\n", "we can convert the response into a Python dictionary\n", "and then access the response `data` within." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GL4L8o4KRQLx" }, "outputs": [], "source": [ "import json\n", "\n", "\n", "print(json.loads(response)[\"data\"][0])" ] }, { "cell_type": "markdown", "metadata": { "id": "rhOc0fgrRtuO" }, "source": [ "Importantly, the `echo | curl` command\n", "does not need to be run from the same machine that is running the model --\n", "that's another big win for this UI over the CLI script we ran previously.\n", "\n", "Try running the command from your own machine,\n", "if you are running OS X or Linux,\n", "and see if you can get a response.\n", "\n", "Don't forget to define the `API_URL` environment variable on your machine\n", "and download the image file,\n", "`text_recognizer/tests/support/paragraphs/a01-077.png`,\n", "changing the path if needed." ] }, { "cell_type": "markdown", "metadata": { "id": "cd1UZiM3ZVWz" }, "source": [ "Once you're done,\n", "turn off the Gradio interface by running the `.close` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mVyv6KjxJhEb" }, "outputs": [], "source": [ "frontend.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "qnJpCdI7SHiX" }, "source": [ "## Testing our UI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've added a lot of new functionality here,\n", "and some of it is critical to our application.\n", "\n", "The surface area is too large and\n", "the components too complex for testing in depth\n", "to be worth the investment --\n", "do we really want to set up a\n", "[headless browser](https://www.browserstack.com/guide/what-is-headless-browser-testing)\n", "or similar mock test to check whether our README is being loaded properly?\n", "\n", "So once again, we pick the minimal test that checks whether\n", "the core functionality is working:\n", "we spin up our frontend and ping the API,\n", "making sure we get back a\n", "[`200 
OK`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/200)\n", "response, indicating that at least the server thinks everything is fine." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!cat app_gradio/tests/test_app.py" ] }, { "cell_type": "markdown", "metadata": { "id": "IwUhy-swZndq" }, "source": [ "## Start here, finish anywhere" ] }, { "cell_type": "markdown", "metadata": { "id": "FTKMGCasMznl" }, "source": [ "You may be concerned:\n", "is `gradio` a children's toy?\n", "am I painting myself into a corner\n", "by using such a high-level framework and doing web development in Python?\n", "shouldn't I be using Ruby On Rails/Angular/React/WhateversNext.js?\n", "\n", "DALL-E Mini, now\n", "[crAIyon](https://www.craiyon.com/),\n", "began its life as\n", "[a Gradio app](https://huggingface.co/spaces/dalle-mini/dalle-mini)\n", "built by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma).\n", "\n", "Gradio and similar tools\n", "are critical for quickly getting to an MVP\n", "and getting useful feedback on your model.\n", "\n", "Expend your engineering effort on data and training,\n", "not frontend interface development,\n", "until you're sure you've got something people want to use." 
] }, { "cell_type": "markdown", "metadata": { "id": "8BpPtj6tsP-Y" }, "source": [ "# Wrapping a model into a model service" ] }, { "cell_type": "markdown", "metadata": { "id": "ButF0a6PSbMi" }, "source": [ "We've got an interactive interface for our model\n", "that we can share with friends, colleagues,\n", "potential users, or stakeholders,\n", "which is huge.\n", "\n", "But we have a problem:\n", "our model is running in the same place as our frontend.\n", "\n", "This is simple,\n", "but it ties too many things together.\n", "\n", "First, it ties together execution of the two components.\n", "\n", "If the model has a heart attack due to misformatted inputs\n", "or some mysterious DNN bug,\n", "the server goes down.\n", "The same applies in reverse --\n", "the only API for the model is provided by `gradio`,\n", "so a frontend issue means the model is inaccessible.\n", "\n", "Additionally, it ties together dependencies,\n", "since our server and our model are in the same\n", "environment.\n", "\n", "Lastly, it ties together the hardware used to run our\n", "server and our model.\n", "\n", "That's bad because the server and the model scale differently.\n", "Running the server at scale has different memory and computational requirements\n", "than does running the model at scale." ] }, { "cell_type": "markdown", "metadata": { "id": "HNoMc7fRcETy" }, "source": [ "We could just run another server --\n", "even writing it in Gradio if we wanted! --\n", "for the model.\n", "This is common with GPU inference,\n", "especially when doing queueing, caching,\n", "and other advanced techniques for improving\n", "model efficiency and latency.\n", "\n", "But that's potentially expensive --\n", "we're running two machines,\n", "which costs twice as much.\n", "\n", "Furthermore, this setup is harder to scale \"horizontally\".\n", "\n", "We'll pretty quickly need a solution for auto-scaling\n", "our two servers independently,\n", "e.g. 
directly in a container orchestration service, like\n", "[Kubernetes](https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/),\n", "or in a managed version of the same, like\n", "[Elastic Kubernetes Service](https://aws.amazon.com/eks/),\n", "or with an infrastructure automation tool, like\n", "[Terraform](https://www.terraform.io/)." ] }, { "cell_type": "markdown", "metadata": { "id": "0WI0H6Imcz_h" }, "source": [ "Luckily, there is an easier way, because our model service-plus-UI\n", "combo fits into a common pattern.\n", "\n", "We have a server that we want to be up all the time,\n", "ready to take requests,\n", "but we really only need\n", "the model service to run when a request hits.\n", "\n", "And apart from its environment (which includes the weights),\n", "the model only needs the request in order to produce a result.\n", "\n", "It does not need to hold onto any information in between executions --\n", "it is _stateless_.\n", "\n", "This pattern is common enough that all cloud providers\n", "offer a solution that takes the pain out of scaling\n", "the stateless component:\n", "\"serverless cloud functions\",\n", "so named because\n", "- they are run intermittently, rather than 24/7, like a server.\n", "- they are run on cloud infrastructure.\n", "- they are, as in\n", "[purely functional programming](https://en.wikipedia.org/wiki/Purely_functional_programming)\n", "or in mathematics, \"pure\" functions of their inputs,\n", "with no concept of state." ] }, { "cell_type": "markdown", "metadata": { "id": "eE_FhWxLhhxG" }, "source": [ "We use AWS's serverless offering,\n", "[AWS Lambda](https://aws.amazon.com/lambda/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xw3Una-yJ_mP" }, "outputs": [], "source": [ "from api_serverless import api\n", "\n", "api??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FGAeXmfFiYOi" }, "source": [ "Our main function here, `api.handler`, wraps `ParagraphTextRecognizer.predict`.\n", "\n", "Effectively, `api.handler` maps HTTP requests (`event`s) with AWS's canonical format\n", "to a format our `ParagraphTextRecognizer` understands,\n", "then converts the text recognizer's output into something\n", "that AWS understands.\n", "\n", "Deploying models as web services is an exercise in taking\n", "the Tensor-to-Tensor-mappings we work with in model development\n", "and wrapping them so that they run in the\n", "JSON-to-JSON-mapping world of web services." ] }, { "cell_type": "markdown", "metadata": { "id": "TDMPQKXqr7pS" }, "source": [ "## Talking to a model service" ] }, { "cell_type": "markdown", "metadata": { "id": "V41-UiMct92x" }, "source": [ "Setting up a serverless function on AWS requires an account\n", "(which requires putting down a credit card)\n", "and configuration of permissions\n", "(which is error-prone).\n", "\n", "If you want to see how that process works,\n", "check out our\n", "[\"bonus notebook\" on serverless deployment on AWS Lambda](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/notebooks/lab99_serverless_aws.ipynb).\n", "Heads up: it uses Docker,\n", "which means it's not compatible with Google Colab.\n", "\n", "So we'll skip that step and,\n", "like Julia Child or Martha Stewart, check out\n", "[one that was prepared earlier](https://tvtropes.org/pmwiki/pmwiki.php/Main/OneIPreparedEarlier).\n", "\n", "The cell below sends a request\n", "to a serverless cloud function running on the FSDL AWS account.\n", "\n", "This request is\n", "much like the one we sent to the API provided by `gradio`,\n", "but we here construct and send it in Python,\n", "using the `requests` library,\n", "rather than operating from the command line.\n", "\n", "When playing around with an API,\n", "writing requests and parsing responses \"by hand\"\n", "in 
the command line is helpful,\n", "but once we're working on real use cases for the API,\n", "we'll want to use higher-level libraries\n", "with good code quality and nice integrations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "76HwEP2Vzz3F" }, "outputs": [], "source": [ "import json\n", "\n", "from IPython.display import Image\n", "import requests # the preferred library for writing HTTP requests in Python\n", "\n", "lambda_url = \"https://3akxma777p53w57mmdika3sflu0fvazm.lambda-url.us-west-1.on.aws/\"\n", "image_url = \"https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png\"\n", "\n", "headers = {\"Content-Type\": \"application/json\"} \n", "payload = json.dumps({\"image_url\": image_url})\n", "\n", "response = requests.post( # we POST the image to the URL, expecting a prediction as a response\n", " lambda_url, data=payload, headers=headers)\n", "pred = response.json()[\"pred\"] # the response is also json\n", "\n", "print(pred)\n", "\n", "Image(url=image_url, width=512)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before deploying a service like this one,\n", "it's important to check how well it handles different traffic volumes and patterns.\n", "This process is known as _load-testing_.\n", "\n", "For a quick tutorial on some basic tooling and a run-through of\n", "results from load-testing the FSDL Text Recognizer on AWS Lambda, see\n", "[this \"bonus notebook\" on load-testing](https://fsdl.me/loadtesting-colab)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bZQ2Dt4URN9o" }, "source": [ "## Local in the front, serverless in the back" ] }, { "cell_type": "markdown", "metadata": { "id": "XMXWTHt4Pxpr" }, "source": [ "The primary \"win\" here\n", "is that we don't need to run\n", "the frontend UI server\n", "and the backend model service in\n", "the same place.\n", "\n", "For example,\n", "we can run a Gradio app locally\n", "but send the images to the serverless function\n", "for prediction.\n", "\n", "Our `app_gradio` implementation supports this via the `PredictorBackend`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4qZ1K0fwOtYK" }, "outputs": [], "source": [ "serverless_backend = app.PredictorBackend(url=lambda_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "5NVVU2JEPSpy" }, "source": [ "Previously, our `PredictorBackend`\n", "was just a wrapper around the `ParagraphTextRecognizer` class.\n", "\n", "By passing a URL,\n", "we switch to sending data elsewhere via an HTTP request.\n", "\n", "This is done by the\n", "`_predict_from_endpoint` method,\n", "which runs effectively the same code we used\n", "to talk to the model service in the cell above." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HtSppJq2O_B_" }, "outputs": [], "source": [ "serverless_backend._predict_from_endpoint??" ] }, { "cell_type": "markdown", "metadata": { "id": "DKA68zxUUO9e" }, "source": [ "The frontend doesn't care where the inference is getting done or how.\n", "\n", "A `gradio.Interface`\n", "just knows there's a Python function that it invokes and then \n", "waits for outputs from.\n", "\n", "Here, that Python function\n", "makes a request to the serverless backend,\n", "rather than running the model.\n", "\n", "Go ahead and try it out!\n", "\n", "You won't notice a difference,\n", "except that the machine you're running this notebook on\n", "no longer runs the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WEkMzohnOcK0" }, "outputs": [], "source": [ "frontend_serverless_backend = app.make_frontend(serverless_backend.run)\n", "\n", "frontend_serverless_backend.launch(share=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "XytXrIWVuRFu" }, "source": [ "# Serving a `gradio` app with `ngrok`" ] }, { "cell_type": "markdown", "metadata": { "id": "2i64HrL1wa7F" }, "source": [ "We've now got a model service and a web server\n", "that we can stand up and scale independently,\n", "but we're not quite done yet.\n", "\n", "First, our URL is controlled by Gradio.\n", "\n", "Very quickly once we leave the territory of a minimal demo,\n", "we'll want that URL to be branded.\n", "\n", "Relatedly,\n", "you may have noticed messages indicating that the public URL\n", "from Gradio is only good for 72 hours.\n", "\n", "That means we'd have to reset our frontend\n", "and share a new URL every few days." ] }, { "cell_type": "markdown", "metadata": { "id": "clsPvqtJu0V0" }, "source": [ "For projects that are mostly intended as public demos,\n", "you might follow the advice from those printed warnings\n", "and use\n", "[Hugging Face Spaces](https://huggingface.co/docs/hub/spaces)\n", "for free, permanent hosting.\n", "\n", "This relieves you of the burden of keeping the frontend server running.\n", "\n", "However, note that this requires you to use the Hugging Face Hub\n", "as a remote for your `git` repository, alongside GitHub or GitLab.\n", "This connection to the version control system can make for tricky integration,\n", "e.g. 
the need to create a new repository for each new model.\n", "\n", "By default, the demo is embedded inside Hugging Face,\n", "limiting your control over the look and feel.\n", "\n", "However, you can embed the demo in another website with\n", "[Web Components or IFrames](https://gradio.app/sharing_your_app/#embedding-with-web-components).\n", "You can also adapt the aesthetics and interactivity of the demo with\n", "[custom CSS and JS](https://gradio.app/custom_CSS_and_JS/).\n", "\n", "We will instead run the frontend server ourselves\n", "and provide a public URL\n", "without relying on Gradio's service." ] }, { "cell_type": "markdown", "metadata": { "id": "XWxKXSSG0yNX" }, "source": [ "Half of the work is already done for us:\n", "the `gradio` frontend is already listening on a port and IP address\n", "that is accessible locally\n", "(on `127.0.0.1` or `localhost`, as printed below)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ugupgc1bxQlH" }, "outputs": [], "source": [ "frontend_serverless_backend.local_url" ] }, { "cell_type": "markdown", "metadata": { "id": "GWcQa-ks1Ktn" }, "source": [ "So we can, for example, send `curl` requests locally,\n", "i.e. on the same machine as the frontend,\n", "and get responses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z4JRaVjH0kPw" }, "outputs": [], "source": [ "# we send an improperly formatted request, because we just want to check for a response\n", "\n", "!curl -X POST {frontend_serverless_backend.local_url}api/predict" ] }, { "cell_type": "markdown", "metadata": { "id": "ZK4-tPGf32Hf" }, "source": [ "Running the same command on another machine will result in an error --\n", "`127.0.0.1` and `localhost` always mean \"on this machine\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "Eiwa6INa0PGe" }, "source": [ "So fundamentally,\n", "the goal is to take the frontend service\n", "running on an IP and port that is only accessible locally\n", "and make it accessible globally." ] }, { "cell_type": "markdown", "metadata": { "id": "Cuuj13Xk0M0Q" }, "source": [ "There's some tricky bits here --\n", "for example, you'll want to communicate using encryption,\n", "i.e. over HTTPS instead of HTTP --\n", "that make doing this entirely on your own\n", "a bit of a headache.\n", "\n", "To avoid these issues,\n", "we can once again use\n", "[`ngrok`](https://ngrok.com/),\n", "the service we used to provide access to our Label Studio instance\n", "in the data annotation lab.\n", "\n", "The free tier includes public URLs and secure communication with HTTPS.\n", "\n", "However, the URL changes each time you relaunch your service,\n", "e.g. after an outage or a version update.\n", "\n", "The paid tier allows for branded domains,\n", "simpler authentication with\n", "[OAuth](https://oauth.net/),\n", "and some basic scaling tools like load balancing.\n", "\n", "This is what we use for the official FSDL text recognizer at\n", "[fsdl-text-recognizer.ngrok.io](https://fsdl-text-recognizer.ngrok.io/)." ] }, { "cell_type": "markdown", "metadata": { "id": "IoKA_VUr4Gf2" }, "source": [ "To get started, let's\n", "set up our `ngrok` credentials." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3N2jkwdaLZAu" }, "outputs": [], "source": [ "import os\n", "import getpass\n", "\n", "from pyngrok import ngrok\n", "\n", "config_file = ngrok.conf.DEFAULT_NGROK_CONFIG_PATH\n", "config_file_exists = os.path.exists(config_file)\n", "config_file_contents = !cat {config_file}\n", "\n", "auth_token_found = config_file_exists \\\n", " and config_file_contents \\\n", " and \"authtoken\" in config_file_contents[0] \\\n", " and \": exit\" not in config_file_contents # state if interrupted\n", "\n", "if not auth_token_found:\n", " print(\"Enter your ngrok auth token, which can be copied from https://dashboard.ngrok.com/auth\")\n", " !ngrok authtoken {getpass.getpass()}" ] }, { "cell_type": "markdown", "metadata": { "id": "m3SaBJn14YA_" }, "source": [ "From there,\n", "it's as simple as pointing\n", "an `ngrok` tunnel\n", "at the port associated with your frontend.\n", "\n", "> For our purposes, ports are\n", "\"places you can listen for messages to your web service\".\n", "By separating ports,\n", "which are identifiers within a machine,\n", "from URLs/IPs,\n", "which are identifiers across machines,\n", "we can run multiple services on a single machine." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wURZiaA5LkeF" }, "outputs": [], "source": [ "TEXT_RECOGNIZER_PORT = frontend_serverless_backend.server_port\n", "\n", "https_tunnel = ngrok.connect(TEXT_RECOGNIZER_PORT, bind_tls=True)\n", "print(https_tunnel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Head to the printed `ngrok.io` URL from any device --\n", "e.g. a mobile phone --\n", "to check out your shiny new ML-powered application UI\n", "with serverless backend." 
] }, { "cell_type": "markdown", "metadata": { "id": "XWYBGHLs5iwN" }, "source": [ "Running a web service out of a Jupyter notebook is not recommended.\n", "\n", "`gradio` and `ngrok`\n", "can be run from the command line.\n", "\n", "If you're running the lab locally,\n", "just define the `TEXT_RECOGNIZER_PORT`\n", "and `LAMBDA_URL` environment variables\n", "and then run\n", "\n", "```bash\n", "python app_gradio/app.py --model_url $LAMBDA_URL --model_port $TEXT_RECOGNIZER_PORT\n", "```\n", "\n", "in one terminal\n", "and, in a separate terminal,\n", "run\n", "```bash\n", "ngrok http $TEXT_RECOGNIZER_PORT\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nycSygTy-PcQ" }, "source": [ "and navigate to the printed URL." ] }, { "cell_type": "markdown", "metadata": { "id": "oQCpzYzHRGfd" }, "source": [ "## Launching a server on a cloud instance" ] }, { "cell_type": "markdown", "metadata": { "id": "RKKnzQjmQPV8" }, "source": [ "We are almost, but not quite,\n", "to the point of a reasonably professional web service.\n", "\n", "The last missing piece is that our server is running\n", "either on Colab,\n", "which has short uptimes and is not intended for serving,\n", "or on our own personal machine,\n", "which is also likely a few\n", "[nines](https://en.wikipedia.org/wiki/High_availability#Percentage_calculation) short of an uptime SLA." ] }, { "cell_type": "markdown", "metadata": { "id": "IKOuYfpTQR-c" }, "source": [ "We want to instead run this on a dedicated server,\n", "and the simplest way to do so is to spin up a machine in a cloud provider.\n", "\n", "[Elastic Compute Cloud](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/concepts.html)\n", "(aka EC2)\n", "is the option in AWS,\n", "our chosen cloud provider."
] }, { "cell_type": "markdown", "metadata": { "id": "15NI6gI1746O" }, "source": [ "To get the server going on another machine,\n", "we'll need to `git clone` our library,\n", "`pip install` our `prod` requirements,\n", "and then finally run `ngrok` and `app_gradio/app.py`." ] }, { "cell_type": "markdown", "metadata": { "id": "faStq6aV-hci" }, "source": [ "We can make that process slightly easier\n", "by incorporating it into a `Dockerfile`\n", "and building a container image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_1i0M7hR-moU" }, "outputs": [], "source": [ "!cat app_gradio/Dockerfile" ] }, { "cell_type": "markdown", "metadata": { "id": "jskTeGs9AroE" }, "source": [ "We can then store the container image in a registry, like\n", "[Docker Hub](https://hub.docker.com/)\n", "or the container image registry built into our cloud provider, like AWS's\n", "[Elastic Container Registry](https://aws.amazon.com/ecr/).\n", "\n", "Then, setup just means pulling the image down onto the machine\n", "we want to run our server from and executing a `docker run` command." 
] } ], "metadata": { "colab": { "collapsed_sections": [], "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab07/tasks/lint.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false # apply automatic formatting echo "black" pre-commit run black || FAILURE=true # check for python code style violations, see .flake8 for details echo "flake8" pre-commit run flake8 || FAILURE=true # check for shell scripting style violations and common bugs echo "shellcheck" pre-commit run shellcheck || FAILURE=true # check python types echo "mypy" pre-commit run mypy || FAILURE=true if [ "$FAILURE" = true ]; then echo "Linting failed" exit 1 fi echo "Linting passed" exit 0 ================================================ FILE: lab07/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab07/text_recognizer/callbacks/__init__.py ================================================ from .model import ModelSizeLogger from .optim import LearningRateMonitor from . 
import imtotext from .imtotext import ImageToTextTableLogger as ImageToTextLogger ================================================ FILE: lab07/text_recognizer/callbacks/imtotext.py ================================================ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only try: import wandb has_wandb = True except ImportError: has_wandb = False from .util import check_and_warn class ImageToTextTableLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.on_train: if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "validation/predictions") def _log_image_text_table(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = xs[:mx], gt_strs[:mx], pred_strs[:mx] xs = [wandb.Image(x) for x in xs] rows = zip(*[xs, gt_strs, pred_strs]) columns = ["input_image", "ground_truth_string", "predicted_string"] trainer.logger.log_table(key=key, columns=columns, data=list(rows)) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) class ImageToTextCaptionLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & 
Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "validation/predictions") @rank_zero_only def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "test/predictions") def _log_image_text_caption(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], pred_strs[:mx] trainer.logger.log_image(key, xs, caption=pred_strs) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) ================================================ FILE: lab07/text_recognizer/callbacks/model.py ================================================ import os from pathlib import Path import tempfile import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_only import torch from .util import check_and_warn, logging try: import torchviz has_torchviz = True except ImportError: has_torchviz = False class ModelSizeLogger(pl.Callback): """Logs information about model size (in parameters and on disk).""" def 
__init__(self, print_size=True): super().__init__() self.print_size = print_size @rank_zero_only def on_fit_start(self, trainer, module): self._run(trainer, module) def _run(self, trainer, module): metrics = {} metrics["mb_disk"] = self.get_model_disksize(module) metrics["nparams"] = count_params(module) if self.print_size: print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB") metrics = {f"size/{key}": value for key, value in metrics.items()} trainer.logger.log_metrics(metrics, step=-1) @staticmethod def get_model_disksize(module): """Determine the model's size on disk by saving it to disk.""" with tempfile.NamedTemporaryFile() as f: torch.save(module.state_dict(), f) size_mb = os.path.getsize(f.name) / 1e6 return size_mb class GraphLogger(pl.Callback): """Logs a compute graph as an image.""" def __init__(self, output_key="logits"): super().__init__() self.graph_logged = False self.output_key = output_key if not has_torchviz: raise ImportError("GraphLogCallback requires torchviz." "") @rank_zero_only def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx): if not self.graph_logged: try: outputs = outputs[0][0]["extra"] self.log_graph(trainer, module, outputs[self.output_key]) except KeyError: logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}") self.graph_logged = True @staticmethod def log_graph(trainer, module, outputs): if check_and_warn(trainer.logger, "log_image", "graph"): return params_dict = dict(list(module.named_parameters())) graph = torchviz.make_dot(outputs, params=params_dict) graph.format = "png" fname = Path(trainer.logger.experiment.dir) / "graph" graph.render(fname) fname = str(fname.with_suffix("." 
+ graph.format)) trainer.logger.log_image(key="graph", images=[fname]) def count_params(module): """Counts the number of parameters in a Torch Module.""" return sum(p.numel() for p in module.parameters()) ================================================ FILE: lab07/text_recognizer/callbacks/optim.py ================================================ import pytorch_lightning as pl KEY = "optimizer" class LearningRateMonitor(pl.callbacks.LearningRateMonitor): """Extends Lightning's LearningRateMonitor with a prefix. Logs the learning rate during training. See the docs for pl.callbacks.LearningRateMonitor for details. """ def _add_prefix(self, *args, **kwargs) -> str: return f"{KEY}/" + super()._add_prefix(*args, **kwargs) ================================================ FILE: lab07/text_recognizer/callbacks/util.py ================================================ import logging logging.basicConfig(level=logging.WARNING) def check_and_warn(logger, attribute, feature): if not hasattr(logger, attribute): warn_no_attribute(feature, attribute) return True def warn_no_attribute(blocked_feature, missing_attribute): logging.warning(f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}.") ================================================ FILE: lab07/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. 
""" from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines from .iam_paragraphs import IAMParagraphs from .iam_lines import IAMLines from .fake_images import FakeImageData from .iam_synthetic_paragraphs import IAMSyntheticParagraphs from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs ================================================ FILE: lab07/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed 
data parallel mode, we launch a training process on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS class BaseDataModule(pl.LightningDataModule): """Base for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
""" def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting. Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab07/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. 
TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, 
y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} 
 with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab07/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab07/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims =
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines) ================================================ FILE: lab07/text_recognizer/data/fake_images.py ================================================ """A fake image dataset for testing.""" import argparse import torch import torchvision from text_recognizer.data.base_data_module import BaseDataModule _NUM_SAMPLES = 512 _IMAGE_LEN = 28 _NUM_CLASSES = 10 class FakeImageData(BaseDataModule): """Fake images dataset.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.num_samples = self.args.get("num_samples", _NUM_SAMPLES) self.input_dims = (1, self.args.get("image_height", _IMAGE_LEN), self.args.get("image_width", _IMAGE_LEN)) self.num_classes = self.args.get("num_classes", _NUM_CLASSES) self.output_dims = (self.num_classes, 1) self.mapping = list(range(0, self.num_classes)) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--num_samples", type=int, default=_NUM_SAMPLES) parser.add_argument("--num_classes", type=int, default=_NUM_CLASSES) parser.add_argument("--image_height", type=int, default=_IMAGE_LEN) parser.add_argument("--image_width", type=int, default=_IMAGE_LEN) return parser def setup(self, stage: str = None) -> None: fake_dataset = torchvision.datasets.FakeData( size=self.num_samples, image_size=self.input_dims, num_classes=self.output_dims[0], transform=torchvision.transforms.ToTensor(), ) val_size = int(self.num_samples * 0.25) self.data_train, self.data_val, self.data_test = torch.utils.data.random_split( # type: ignore dataset=fake_dataset, lengths=[self.num_samples - 2 * val_size, val_size, val_size] ) ================================================ FILE: lab07/text_recognizer/data/iam.py
================================================ """Class for loading the IAM handwritten text dataset, which encompasses both paragraphs and lines, plus utilities.""" from pathlib import Path from typing import Any, cast, Dict, List, Optional import zipfile from boltons.cacheutils import cachedproperty from defusedxml import ElementTree from PIL import Image, ImageOps import toml from text_recognizer import util from text_recognizer.data.base_data_module import _download_raw_dataset, load_and_print_info import text_recognizer.metadata.iam as metadata from text_recognizer.metadata.iam_paragraphs import NEW_LINE_TOKEN METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME EXTRACTED_DATASET_DIRNAME = metadata.EXTRACTED_DATASET_DIRNAME class IAM: """A dataset of images of handwritten text written on a form underneath a typewritten prompt. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database Images are identified by their "form ID". These IDs are used to separate train, validation and test splits, as keys for dictionaries returning label and image crop region data, and more. The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (LWITLRT): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. 
""" def __init__(self): self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self): if self.xml_filenames: return filename = _download_raw_dataset(self.metadata, DL_DATA_DIRNAME) # type: ignore _extract_raw_dataset(filename, DL_DATA_DIRNAME) def load_image(self, id: str) -> Image.Image: """Load and return an image of an entire IAM form. The image is grayscale with white text on black background. This image will have the printed prompt text at the top, above the handwritten text. Images of individual words or lines and of whole paragraphs can be cropped out using the relevant crop region data. """ image = util.read_image_pil(self.form_filenames_by_id[id], grayscale=True) image = ImageOps.invert(image) return image def __repr__(self): """Print info about the dataset.""" info = ["IAM Dataset"] info.append(f"Total Images: {len(self.xml_filenames)}") info.append(f"Total Test Images: {len(self.test_ids)}") info.append(f"Total Paragraphs: {len(self.paragraph_string_by_id)}") # Sum over the per-form region lists (the dict values), not over (id, regions) item tuples num_lines = sum(len(line_regions) for line_regions in self.line_regions_by_id.values()) info.append(f"Total Lines: {num_lines}") return "\n\t".join(info) @cachedproperty def all_ids(self): """A list of all form IDs.""" return sorted([f.stem for f in self.xml_filenames]) @cachedproperty def ids_by_split(self): return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids} @cachedproperty def split_by_id(self): """A dictionary mapping form IDs to their split according to IAM Lines LWITLRT.""" split_by_id = {id_: "train" for id_ in self.train_ids} split_by_id.update({id_: "val" for id_ in self.validation_ids}) split_by_id.update({id_: "test" for id_ in self.test_ids}) return split_by_id @cachedproperty def train_ids(self): """A list of form IDs which are in the IAM Lines LWITLRT training set.""" return list(set(self.all_ids) - (set(self.test_ids) | set(self.validation_ids))) @cachedproperty def test_ids(self): """A list of form IDs from the IAM Lines LWITLRT test set.""" 
return _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/testset.txt") @property def xml_filenames(self) -> List[Path]: """A list of the filenames of all .xml files, which contain label information.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @cachedproperty def validation_ids(self): """A list of form IDs from IAM Lines LWITLRT validation sets 1 and 2.""" val_ids = _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset1.txt") val_ids.extend(_get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset2.txt")) return val_ids @property def form_filenames(self) -> List[Path]: """A list of the filenames of all .jpg files, which contain images of IAM forms.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def xml_filenames_by_id(self): """A dictionary mapping form IDs to their XML label information files.""" return {filename.stem: filename for filename in self.xml_filenames} @property def form_filenames_by_id(self): """A dictionary mapping form IDs to their JPEG images.""" return {filename.stem: filename for filename in self.form_filenames} @cachedproperty def line_strings_by_id(self): """A dict mapping an IAM form id to its list of line texts.""" return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def line_regions_by_id(self): """A dict mapping an IAM form id to its list of line image crop regions.""" return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def paragraph_string_by_id(self): """A dict mapping an IAM form id to its paragraph text.""" return {id: NEW_LINE_TOKEN.join(line_strings) for id, line_strings in self.line_strings_by_id.items()} @cachedproperty def paragraph_region_by_id(self): """A dict mapping an IAM form id to its paragraph image crop region.""" return { id: { "x1": min(region["x1"] for region in line_regions), 
"y1": min(region["y1"] for region in line_regions), "x2": max(region["x2"] for region in line_regions), "y2": max(region["y2"] for region in line_regions), } for id, line_regions in self.line_regions_by_id.items() } def _extract_raw_dataset(filename: Path, dirname: Path) -> None: print("Extracting IAM data") with util.temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zip_file: zip_file.extractall() def _get_ids_from_lwitlrt_split_file(filename: str) -> List[str]: """Get the ids from Large Writer Independent Text Line Recognition Task (LWITLRT) data split file.""" with open(filename, "r") as f: line_ids_str = f.read() line_ids = line_ids_str.split("\n") page_ids = list({"-".join(line_id.split("-")[:2]) for line_id in line_ids if line_id}) return page_ids def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace &quot; with ".""" xml_line_elements = _get_line_elements_from_xml_file(filename) return [_get_text_from_xml_element(el) for el in xml_line_elements] def _get_text_from_xml_element(xml_element: Any) -> str: """Extract text from any XML element.""" return xml_element.attrib["text"].replace("&quot;", '"') def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get the line region dict for each line.""" xml_line_elements = _get_line_elements_from_xml_file(filename) line_regions = [ cast(Dict[str, int], _get_region_from_xml_element(xml_elem=el, xml_path="word/cmp")) for el in xml_line_elements ] # Every region is indexed below, so each one must be present, not just some of them assert all(region is not None for region in line_regions), "Line regions cannot be None" # next_line_region["y1"] - prev_line_region["y2"] can be negative due to overlapping characters line_gaps_y = [ max(next_line_region["y1"] - prev_line_region["y2"], 0) for next_line_region, prev_line_region in zip(line_regions[1:], line_regions[:-1]) ] post_line_gaps_y = line_gaps_y + [2 * metadata.LINE_REGION_PADDING] pre_line_gaps_y = [2 * metadata.LINE_REGION_PADDING] + 
line_gaps_y return [ { "x1": region["x1"] - metadata.LINE_REGION_PADDING, "x2": region["x2"] + metadata.LINE_REGION_PADDING, "y1": region["y1"] - min(metadata.LINE_REGION_PADDING, pre_line_gaps_y[i] // 2), "y2": region["y2"] + min(metadata.LINE_REGION_PADDING, post_line_gaps_y[i] // 2), } for i, region in enumerate(line_regions) ] def _get_line_elements_from_xml_file(filename: str) -> List[Any]: """Get all line xml elements from xml file.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec return xml_root_element.findall("handwritten-part/line") def _get_region_from_xml_element(xml_elem: Any, xml_path: str) -> Optional[Dict[str, int]]: """ Get region from input xml element. The region is downsampled because the stored images are also downsampled. Parameters ---------- xml_elem xml element can be a line or word element with x, y, width, and height attributes xml_path should be "word/cmp" if xml_elem is a line element, else "cmp" """ unit_elements = xml_elem.findall(xml_path) if not unit_elements: return None return { "x1": min(int(el.attrib["x"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y1": min(int(el.attrib["y"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "x2": max(int(el.attrib["x"]) + int(el.attrib["width"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y2": max(int(el.attrib["y"]) + int(el.attrib["height"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, } if __name__ == "__main__": load_and_print_info(IAM) ================================================ FILE: lab07/text_recognizer/data/iam_lines.py ================================================ """A dataset of lines of handwritten text derived from the IAM dataset.""" import argparse import json from pathlib import Path from typing import Sequence import numpy as np from PIL import Image, ImageFile from text_recognizer import util from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam 
import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.line import IAMLineStem ImageFile.LOAD_TRUNCATED_IMAGES = True PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR class IAMLines(BaseDataModule): """Lines of text pulled from the IAM Handwriting database.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true") == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = IAMLineStem() self.trainval_transform = IAMLineStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if PROCESSED_DATA_DIRNAME.exists(): return print("Cropping IAM line regions...") iam = IAM() iam.prepare_data() crops_train, labels_train = generate_line_crops_and_labels(iam, "train") crops_val, labels_val = generate_line_crops_and_labels(iam, "val") crops_test, labels_test = generate_line_crops_and_labels(iam, "test") shapes = np.array([crop.size for crop in crops_train + crops_val + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] print("Saving images, labels, and statistics...") save_images_and_labels(crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_val, labels_val, "val", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt", "w") as file: file.write(str(aspect_ratios.max())) def 
setup(self, stage: str = None) -> None: with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt") as file: max_aspect_ratio = float(file.read()) image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) assert image_width <= metadata.IMAGE_WIDTH if stage == "fit" or stage is None: x_train, labels_train = load_processed_crops_and_labels("train", PROCESSED_DATA_DIRNAME) y_train = convert_strings_to_labels(labels_train, self.inverse_mapping, length=self.output_dims[0]) self.data_train = BaseDataset(x_train, y_train, transform=self.trainval_transform) x_val, labels_val = load_processed_crops_and_labels("val", PROCESSED_DATA_DIRNAME) y_val = convert_strings_to_labels(labels_val, self.inverse_mapping, length=self.output_dims[0]) self.data_val = BaseDataset(x_val, y_val, transform=self.trainval_transform) # quick check: do we have the right sequence lengths? assert self.output_dims[0] >= max([len(_) for _ in labels_train]) + 2 # Add 2 for start/end tokens. assert self.output_dims[0] >= max([len(_) for _ in labels_val]) + 2 # Add 2 for start/end tokens. 
if stage == "test" or stage is None: x_test, labels_test = load_processed_crops_and_labels("test", PROCESSED_DATA_DIRNAME) y_test = convert_strings_to_labels(labels_test, self.inverse_mapping, length=self.output_dims[0]) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) assert self.output_dims[0] >= max([len(_) for _ in labels_test]) + 2 def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Lines Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def generate_line_crops_and_labels(iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR): """Create both cropped lines and associated labels from IAM, with resizing by default""" crops, labels = [], [] for iam_id in iam.ids_by_split[split]: labels += iam.line_strings_by_id[iam_id] image = iam.load_image(iam_id) for line in iam.line_regions_by_id[iam_id]: coords = [line[point] for point in ["x1", "y1", "x2", "y2"]] crop = image.crop(coords) crop = resize_image(crop, scale_factor=scale_factor) crops.append(crop) assert len(crops) == len(labels) return crops, labels def save_images_and_labels(crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path): (data_dirname / split).mkdir(parents=True, exist_ok=True) with open(data_dirname / split / "_labels.json", "w") as f: json.dump(labels, f) 
for ind, crop in enumerate(crops): crop.save(data_dirname / split / f"{ind}.png") def load_processed_crops_and_labels(split: str, data_dirname: Path): """Load line crops and labels for given split from processed directory.""" crops = load_processed_line_crops(split, data_dirname) labels = load_processed_line_labels(split, data_dirname) assert len(crops) == len(labels) return crops, labels def load_processed_line_crops(split: str, data_dirname: Path): """Load line crops for given split from processed directory.""" crop_filenames = sorted((data_dirname / split).glob("*.png"), key=lambda filename: int(Path(filename).stem)) crops = [util.read_image_pil(filename, grayscale=True) for filename in crop_filenames] return crops def load_processed_line_labels(split: str, data_dirname: Path): """Load line labels for given split from processed directory.""" with open(data_dirname / split / "_labels.json") as file: labels = json.load(file) return labels if __name__ == "__main__": load_and_print_info(IAMLines) ================================================ FILE: lab07/text_recognizer/data/iam_original_and_synthetic_paragraphs.py ================================================ """IAM Original and Synthetic Paragraphs Dataset class.""" import argparse from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs class IAMOriginalAndSyntheticParagraphs(BaseDataModule): """A concatenation of original and synthetic IAM paragraph datasets.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.iam_paragraphs = IAMParagraphs(args) self.iam_syn_paragraphs = IAMSyntheticParagraphs(args) self.input_dims = self.iam_paragraphs.input_dims self.output_dims = self.iam_paragraphs.output_dims self.mapping = self.iam_paragraphs.mapping self.inverse_mapping = 
{v: k for k, v in enumerate(self.mapping)} @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") IAMSyntheticParagraphs.add_to_argparse(parser) return parser def prepare_data(self, *args, **kwargs) -> None: self.iam_paragraphs.prepare_data() self.iam_syn_paragraphs.prepare_data() def setup(self, stage: str = None) -> None: self.iam_paragraphs.setup(stage) self.iam_syn_paragraphs.setup(stage) if stage == "fit" or stage is None: self.data_train = ConcatDataset([self.iam_paragraphs.data_train, self.iam_syn_paragraphs.data_train]) self.data_val = self.iam_paragraphs.data_val if stage == "test" or stage is None: self.data_test = self.iam_paragraphs.data_test def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data if __name__ == "__main__": load_and_print_info(IAMOriginalAndSyntheticParagraphs) ================================================ FILE: lab07/text_recognizer/data/iam_paragraphs.py ================================================ """IAM Paragraphs Dataset class.""" import argparse import json from pathlib import Path from typing import Callable, Dict, Optional, Sequence, Tuple import numpy as np 
from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.paragraph import ParagraphStem IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME class IAMParagraphs(BaseDataModule): """IAM Handwriting database paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true").lower() == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = ParagraphStem() self.trainval_transform = ParagraphStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): return rank_zero_info( "IAMParagraphs.prepare_data: Cropping IAM paragraph regions and saving them along with labels..." 
) iam = IAM() iam.prepare_data() properties = {} for split in ["train", "val", "test"]: crops, labels = get_paragraph_crops_and_labels(iam=iam, split=split) save_crops_and_labels(crops=crops, labels=labels, split=split) properties.update( { id_: { "crop_shape": crops[id_].size[::-1], "label_length": len(label), "num_lines": _num_lines(label), } for id_, label in labels.items() } ) with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: def _load_dataset(split: str, transform: Callable) -> BaseDataset: crops, labels = load_processed_crops_and_labels(split) Y = convert_strings_to_labels(strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]) return BaseDataset(crops, Y, transform=transform) rank_zero_info(f"IAMParagraphs.setup({stage}): Loading IAM paragraph regions and lines...") validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) if stage == "fit" or stage is None: self.data_train = _load_dataset(split="train", transform=self.trainval_transform) self.data_val = _load_dataset(split="val", transform=self.transform) if stage == "test" or stage is None: self.data_test = _load_dataset(split="test", transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), 
xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def validate_input_and_output_dimensions( input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] ) -> None: """Validate input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] # Add 2 because of start and end tokens assert output_dims is not None and output_dims[0] >= properties["label_length"]["max"] + 2 def get_paragraph_crops_and_labels( iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR ) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: """Create IAM paragraph crops and labels for a given split, with resizing.""" crops = {} labels = {} for iam_id in iam.ids_by_split[split]: image = iam.load_image(iam_id) para_region = iam.paragraph_region_by_id[iam_id] crops[iam_id] = image.crop([para_region[_] for _ in ["x1", "y1", "x2", "y2"]]) crops[iam_id] = resize_image(crops[iam_id], scale_factor=scale_factor) labels[iam_id] = iam.paragraph_string_by_id[iam_id] assert len(crops) == len(labels) return crops, labels def save_crops_and_labels(crops: Dict[str, Image.Image], labels: Dict[str, str], split: str): """Save crops, labels and shapes of crops of a split.""" (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with open(_labels_filename(split), "w") as f: json.dump(labels, f, indent=4) for id_, crop in crops.items(): crop.save(_crop_filename(id_, split)) def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: """Load processed crops and labels for given split.""" with open(_labels_filename(split), "r") as f: labels = json.load(f) sorted_ids = sorted(labels.keys()) ordered_crops = [Image.open(_crop_filename(id_, split)).convert("L") for id_ in 
sorted_ids] ordered_labels = [labels[id_] for id_ in sorted_ids] assert len(ordered_crops) == len(ordered_labels) return ordered_crops, ordered_labels def get_dataset_properties() -> dict: """Return properties describing the overall dataset.""" with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: properties = json.load(f) def _get_property_values(key: str) -> list: return [_[key] for _ in properties.values()] crop_shapes = np.array(_get_property_values("crop_shape")) aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0] return { "label_length": { "min": min(_get_property_values("label_length")), "max": max(_get_property_values("label_length")), }, "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))}, "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)}, "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, } def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" def _crop_filename(id_: str, split: str) -> Path: """Return filename of processed crop.""" return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" def _num_lines(label: str) -> int: """Return number of lines of text in label.""" return label.count(NEW_LINE_TOKEN) + 1 if __name__ == "__main__": load_and_print_info(IAMParagraphs) ================================================ FILE: lab07/text_recognizer/data/iam_synthetic_paragraphs.py ================================================ """IAM Synthetic Paragraphs Dataset class.""" import argparse import random from typing import Any, Callable, List, Sequence, Tuple import numpy as np from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info import torch from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( generate_line_crops_and_labels, 
load_processed_line_crops, load_processed_line_labels, save_images_and_labels, ) from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.util import convert_strings_to_labels import text_recognizer.metadata.iam_synthetic_paragraphs as metadata NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME DATASET_LEN = metadata.DATASET_LEN class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database synthetic paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.line_crops = None self.line_labels = None self.dataset_len = self.args.get("dataset_len", DATASET_LEN) def prepare_data(self, *args, **kwargs) -> None: """ Prepare IAM lines such that they can be used to generate synthetic paragraphs dataset in setup(). This method is IAMLines.prepare_data + resizing of line crops. """ if PROCESSED_DATA_DIRNAME.exists(): return rank_zero_info( "IAMSyntheticParagraphs.prepare_data: preparing IAM lines for synthetic IAM paragraph creation..." 
) iam = IAM() iam.prepare_data() for split in ["train"]: # synthetic dataset is only used in training phase rank_zero_info(f"Cropping IAM line regions and loading labels for {split} data split...") crops, labels = generate_line_crops_and_labels(iam, split) save_images_and_labels(crops, labels, split, PROCESSED_DATA_DIRNAME) def setup(self, stage: str = None) -> None: rank_zero_info(f"IAMSyntheticParagraphs.setup({stage}): Loading train IAM paragraph regions and lines...") if stage == "fit" or stage is None: self._load_processed_crops_and_labels() self.data_train = IAMSyntheticParagraphsDataset( line_crops=self.line_crops, line_labels=self.line_labels, dataset_len=self.dataset_len, inverse_mapping=self.inverse_mapping, input_dims=self.input_dims, output_dims=self.output_dims, transform=self.trainval_transform, ) def _load_processed_crops_and_labels(self): if self.line_crops is None: self.line_crops = load_processed_line_crops("train", PROCESSED_DATA_DIRNAME) if self.line_labels is None: self.line_labels = load_processed_line_labels("train", PROCESSED_DATA_DIRNAME) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Synthetic Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def add_to_argparse(parser): parser.add_argument("--dataset_len", type=int, default=DATASET_LEN) return parser class IAMSyntheticParagraphsDataset(torch.utils.data.Dataset): """Dataset of synthetic paragraphs built out of individual IAM lines.""" def __init__( self, line_crops: List[Image.Image], line_labels: List[str], dataset_len: int, inverse_mapping: dict, input_dims: 
Tuple[int, ...], output_dims: Tuple[int, ...], transform: Callable = None, ) -> None: super().__init__() self.line_crops = line_crops self.line_labels = line_labels assert len(self.line_crops) == len(self.line_labels) self.ids = list(range(len(self.line_labels))) self.dataset_len = dataset_len self.inverse_mapping = inverse_mapping self.input_dims = input_dims self.output_dims = output_dims self.transform = transform self.min_num_lines, self.max_num_lines = 1, 15 self.seed_set = False def __len__(self) -> int: """Return length of the dataset.""" return self.dataset_len def _set_seed(self, seed): if not self.seed_set: print(f"Setting seed to {seed} for worker {torch.utils.data.get_worker_info()}") random.seed(seed) self.seed_set = True def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a random paragraph, using the first index as a seed.""" # Since shuffle is True for train dataloaders, the first index will be different on different GPUs self._set_seed(index) num_lines = random.randint(self.min_num_lines, self.max_num_lines) indices = random.sample(self.ids, k=num_lines) while True: datum = join_line_crops_to_form_paragraph([self.line_crops[i] for i in indices]) labels = NEW_LINE_TOKEN.join([self.line_labels[i] for i in indices]) if ( (len(labels) <= self.output_dims[0] - 2) and (datum.height <= self.input_dims[1]) and (datum.width <= self.input_dims[2]) ): break indices = indices[:-1] if self.transform is not None: datum = self.transform(datum) length = self.output_dims[0] target = convert_strings_to_labels(strings=[labels], mapping=self.inverse_mapping, length=length)[0] return datum, target def join_line_crops_to_form_paragraph(line_crops: Sequence[Image.Image]) -> Image.Image: """Vertically stack line crops and return a single image forming the paragraph.""" crop_shapes = np.array([_.size[::-1] for _ in line_crops]) para_height = crop_shapes[:, 0].sum() para_width = crop_shapes[:, 1].max() para_image = Image.new(mode="L", size=(para_width, 
para_height), color=0) current_height = 0 for line_crop in line_crops: para_image.paste(line_crop, box=(0, current_height)) current_height += line_crop.height return para_image if __name__ == "__main__": load_and_print_info(IAMSyntheticParagraphs) ================================================ FILE: lab07/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab07/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import string from typing import List, Optional import nltk import numpy as np from 
text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def __init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in string.punctuation}) text = re.sub(" +", " ", text) return text def 
load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab07/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. 
        Parameters
        ----------
        index

        Returns
        -------
        (datum, target)
        """
        datum, target = self.data[index], self.targets[index]

        if self.transform is not None:
            datum = self.transform(datum)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return datum, target


def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor:
    """
    Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens,
    and padded with the <P> token.
    """
    labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"]
    for i, string in enumerate(strings):
        tokens = list(string)
        tokens = ["<S>", *tokens, "<E>"]
        for ii, token in enumerate(tokens):
            labels[i, ii] = mapping[token]
    return labels


def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]:
    """
    Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the
    other of size (1 - fraction) * size of the base_dataset.
    """
    split_a_size = int(fraction * len(base_dataset))
    split_b_size = len(base_dataset) - split_a_size
    return torch.utils.data.random_split(  # type: ignore
        base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed)
    )


def resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
    """Resize image by scale factor."""
    if scale_factor == 1:
        return image
    return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR)


================================================
FILE: lab07/text_recognizer/lit_models/__init__.py
================================================
from .base import BaseLitModel
from .transformer import TransformerLitModel


================================================
FILE: lab07/text_recognizer/lit_models/base.py
================================================
"""Basic LightningModules on which other modules can be built."""
import argparse

import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy

from .metrics import CharacterErrorRate

OPTIMIZER = "Adam"
LR = 1e-3
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100


class BaseLitModel(pl.LightningModule):
    """
    Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
""" def __init__(self, model, args: argparse.Namespace = None): super().__init__() self.model = model self.args = vars(args) if args is not None else {} self.data_config = self.model.data_config self.mapping = self.data_config["mapping"] self.input_dims = self.data_config["input_dims"] optimizer = self.args.get("optimizer", OPTIMIZER) self.optimizer_class = getattr(torch.optim, optimizer) self.lr = self.args.get("lr", LR) loss = self.args.get("loss", LOSS) if loss not in ("transformer",): self.loss_fn = getattr(torch.nn.functional, loss) self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() @staticmethod def add_to_argparse(parser): parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") parser.add_argument("--lr", type=float, default=LR) parser.add_argument("--one_cycle_max_lr", type=float, default=None) parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional") return parser def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr) if self.one_cycle_max_lr is None: return optimizer scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps ) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} def forward(self, x): return self.model(x) def predict(self, x): logits = self.model(x) return torch.argmax(logits, dim=1) def training_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.train_acc(logits, y) self.log("train/loss", loss) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) outputs = {"loss": loss} 
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def _run_on_batch(self, batch, with_preds=False):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        return x, y, logits, loss

    def validation_step(self, batch, batch_idx):
        x, y, logits, loss = self._run_on_batch(batch)
        self.val_acc(logits, y)

        self.log("validation/loss", loss, prog_bar=True, sync_dist=True)
        self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

        outputs = {"loss": loss}
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def test_step(self, batch, batch_idx):
        x, y, logits, loss = self._run_on_batch(batch)
        self.test_acc(logits, y)

        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)

    def add_on_first_batch(self, metrics, outputs, batch_idx):
        if batch_idx == 0:
            outputs.update(metrics)

    def add_on_logged_batches(self, metrics, outputs):
        # is_logged_batch is a method, so it must be called; the bare attribute is always truthy
        if self.is_logged_batch():
            outputs.update(metrics)

    def is_logged_batch(self):
        if self.trainer is None:
            return False
        else:
            return self.trainer._logger_connector.should_update_logs


class BaseImageToTextLitModel(BaseLitModel):  # pylint: disable=too-many-ancestors
    """Base class for ImageToText models in PyTorch Lightning."""

    def __init__(self, model, args: argparse.Namespace = None):
        super().__init__(model, args)
        self.model = model
        self.args = vars(args) if args is not None else {}

        self.inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)}
        self.start_index = self.inverse_mapping["<S>"]
        self.end_index = self.inverse_mapping["<E>"]
        self.padding_index = self.inverse_mapping["<P>"]

        self.ignore_tokens = [self.start_index, self.end_index, self.padding_index]
        self.val_cer = CharacterErrorRate(self.ignore_tokens)
        self.test_cer = CharacterErrorRate(self.ignore_tokens)


================================================
FILE: lab07/text_recognizer/lit_models/metrics.py
================================================
"""Special-purpose metrics for tracking our model performance."""
from typing import Sequence

import torch
import torchmetrics


class CharacterErrorRate(torchmetrics.CharErrorRate):
    """Character error rate metric, allowing for tokens to be ignored."""

    def __init__(self, ignore_tokens: Sequence[int], *args):
        super().__init__(*args)
        self.ignore_tokens = set(ignore_tokens)

    def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
        preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()]
        targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()]
        super().update(preds_l, targets_l)


def test_character_error_rate():
    metric = CharacterErrorRate([0, 1])
    X = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],  # error will be 0
            [0, 2, 1, 1, 1, 1],  # error will be .75
            [0, 2, 2, 4, 4, 1],  # error will be .5
        ]
    )
    Y = torch.tensor(
        [
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
            [0, 2, 2, 3, 3, 1],
        ]
    )
    metric(X, Y)
    assert metric.compute() == sum([0, 0.75, 0.5]) / 3


if __name__ == "__main__":
    test_character_error_rate()


================================================
FILE: lab07/text_recognizer/lit_models/transformer.py
================================================
"""An encoder-decoder Transformer model"""
from typing import List, Sequence

import torch

from .base import BaseImageToTextLitModel
from .util import replace_after


class TransformerLitModel(BaseImageToTextLitModel):
    """
    Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.
The module must implement an encode and decode method, and the forward method should be the forward pass during production inference. """ def __init__(self, model, args=None): super().__init__(model, args) self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index) def forward(self, x): return self.model(x) def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x. Parameters ---------- x Batch of images to be encoded. See self.model.encode for shape information. y Batch of ground truth output sequences. Returns ------- torch.Tensor (B, C, Sy) logits """ x = self.model.encode(x) output = self.model.decode(x, y) # (Sy, B, C) return output.permute(1, 2, 0) # (B, C, Sy) def training_step(self, batch, batch_idx): x, y = batch logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("train/loss", loss) outputs = {"loss": loss} if self.is_logged_batch(): preds = self.get_preds(logits) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs}) return outputs def validation_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.val_cer(preds, y) self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) 
        loss = self.loss_fn(logits, y[:, 1:])

        self.log("test/loss", loss, prog_bar=True, sync_dist=True)

        outputs = {"loss": loss}

        # compute predictions as in production, for comparison
        preds = self(x)
        # use the dedicated test metric so test results do not accumulate into the validation metric
        self.test_cer(preds, y)
        self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True)

        pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y)
        self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx)
        self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx)

        return outputs

    def map(self, ks: Sequence[int], ignore: bool = True) -> str:
        """Maps an iterable of integers to a string using the lit model's mapping."""
        if ignore:
            return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens])
        else:
            return "".join([self.mapping[k] for k in ks])

    def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]:
        """Maps a list of lists of integers to a list of strings using the lit model's mapping."""
        return [self.map(k, ignore) for k in ks]

    def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor:
        """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index.

        Parameters
        ----------
        logitlikes
            (B, C, Sy) Tensor with classes as second dimension. The largest value is the one whose index we will return.
            Logits, logprobs, and probs are all acceptable.
        replace_after_end
            Whether to replace values after the first appearance of the end token with the padding token.

        Returns
        -------
        torch.Tensor
            (B, Sy) Tensor of integers in [0, C-1] representing predictions.
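
        Examples
        --------
        A minimal sketch of the two steps this method performs, written against plain tensors rather than
        a model instance. The end token index (2) and padding token index (3) are assumptions chosen for
        illustration; the real values come from the model's mapping.

        >>> logits = torch.tensor([[[0.9, 0.0, 0.7], [0.0, 0.0, 0.0], [0.0, 0.9, 0.0], [0.0, 0.0, 0.0]]])  # (B=1, C=4, Sy=3)
        >>> raw = torch.argmax(logits, dim=1)  # pick the highest-scoring class at each position
        >>> raw
        tensor([[0, 2, 0]])
        >>> replace_after(raw, 2, 3)  # everything after the first end token becomes padding
        tensor([[0, 2, 3]])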
""" raw = torch.argmax(logitlikes, dim=1) # (B, C, Sy) -> (B, Sy) if replace_after_end: return replace_after(raw, self.end_index, self.padding_index) # (B, Sy) else: return raw # (B, Sy) ================================================ FILE: lab07/text_recognizer/lit_models/util.py ================================================ from typing import Union import torch def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: """Return indices of first appearance of element in x, collapsing along dim. Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 Parameters ---------- x One or two-dimensional Tensor to search for element. element Item to search for inside x. dim Dimension of Tensor to collapse over. Returns ------- torch.Tensor Indices where element occurs in x. If element is not found, return length of x along dim. One dimension smaller than x. Raises ------ ValueError if x is not a 1 or 2 dimensional Tensor Examples -------- >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3) tensor([2, 1, 3, 0]) >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0) tensor(0) """ if x.dim() > 2 or x.dim() == 0: raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}") matches = x == element first_appearance_mask = (matches.cumsum(dim) == 1) & matches does_match, match_index = first_appearance_mask.max(dim) first_inds = torch.where(does_match, match_index, x.shape[dim]) return first_inds def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor: """Replace all values in each row of 2d Tensor x after the first appearance of element with replace. Parameters ---------- x Two-dimensional Tensor (shape denoted (B, S)) to replace values in. element Item to search for inside x. replace Item that replaces entries that appear after element. 
    Returns
    -------
    outs
        New Tensor of same shape as x with values after element replaced.

    Examples
    --------
    >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4)
    tensor([[1, 2, 3],
            [2, 3, 4],
            [1, 1, 1],
            [3, 4, 4]])
    """
    first_appearances = first_appearance(x, element, dim=1)  # (B,)
    indices = torch.arange(0, x.shape[-1]).type_as(x)  # (S,)

    outs = torch.where(
        indices[None, :] <= first_appearances[:, None],  # if index is before first appearance
        x,  # return the value from x
        replace,  # otherwise, return the replacement value
    )

    return outs  # (B, S)


================================================
FILE: lab07/text_recognizer/metadata/emnist.py
================================================
from pathlib import Path

import text_recognizer.metadata.shared as shared

RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist"
PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist"
PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json"

NUM_SPECIAL_TOKENS = 4
INPUT_SHAPE = (28, 28)
DIMS = (1, *INPUT_SHAPE)  # Extra dimension added by ToTensor()
OUTPUT_DIMS = (1,)

MAPPING = [
    "<B>", "<S>", "<E>", "<P>",
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
    "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
    "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
    "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
    " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?",
]


================================================
FILE: lab07/text_recognizer/metadata/emnist_lines.py
================================================
from pathlib import Path

import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines"
ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json"

CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3]
DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None)  # width variable, depends on maximum sequence length

MAPPING = emnist.MAPPING


================================================
FILE: lab07/text_recognizer/metadata/iam.py
================================================
import text_recognizer.metadata.shared as shared

RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam"
EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"

DOWNSAMPLE_FACTOR = 2  # if images were downsampled, the regions must also be
LINE_REGION_PADDING = 8  # add this many pixels around the exact coordinates


================================================
FILE: lab07/text_recognizer/metadata/iam_lines.py
================================================
import text_recognizer.metadata.emnist as emnist
import text_recognizer.metadata.shared as shared

PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines"

IMAGE_SCALE_FACTOR = 2

CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR  # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR # rounding up IAMLines empirical maximum width DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (89, 1) MAPPING = emnist.MAPPING ================================================ FILE: lab07/text_recognizer/metadata/iam_paragraphs.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN] # must match IMAGE_SCALE_FACTOR for IAMLines to be compatible with synthetic paragraphs IMAGE_SCALE_FACTOR = 2 IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640 IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) MAX_LABEL_LENGTH = 682 DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) ================================================ FILE: lab07/text_recognizer/metadata/iam_synthetic_paragraphs.py ================================================ import text_recognizer.metadata.iam_paragraphs as iam_paragraphs import text_recognizer.metadata.shared as shared NEW_LINE_TOKEN = iam_paragraphs.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs" EXPECTED_BATCH_SIZE = 64 EXPECTED_GPUS = 8 EXPECTED_STEPS = 40 # set the dataset's length based on parameters during typical training DATASET_LEN = EXPECTED_BATCH_SIZE * EXPECTED_GPUS * EXPECTED_STEPS ================================================ FILE: lab07/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab07/text_recognizer/metadata/shared.py 
================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab07/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple from .resnet_transformer import ResnetTransformer from .line_cnn_transformer import LineCNNTransformer ================================================ FILE: lab07/text_recognizer/models/cnn.py ================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. 
Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab07/text_recognizer/models/line_cnn.py ================================================ """Basic building blocks for convolutional models over lines of text.""" import argparse import math from typing import Any, Dict, Tuple, Union import torch from torch import nn import torch.nn.functional as F # Common type hints Param2D = Union[int, Tuple[int, int]] CONV_DIM = 32 FC_DIM = 512 FC_DROPOUT = 0.2 WINDOW_WIDTH = 16 WINDOW_STRIDE = 8 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__( self, input_channels: int, output_channels: int, kernel_size: Param2D = 3, stride: Param2D = 1, padding: Param2D = 1, ) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class LineCNN(nn.Module): """ Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits """ def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.args = vars(args) if args is not None else {} self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] _C, H, _W = data_config["input_dims"] conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) # Input is (1, H, W) self.convs = nn.Sequential( ConvBlock(1, conv_dim), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim, stride=2), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim * 2, stride=2), ConvBlock(conv_dim * 2, conv_dim * 2), ConvBlock(conv_dim * 2, conv_dim * 4, stride=2), ConvBlock(conv_dim * 4, conv_dim * 4), ConvBlock( conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0 ), ) self.fc1 = nn.Linear(fc_dim, fc_dim) self.dropout = nn.Dropout(fc_dropout) self.fc2 = nn.Linear(fc_dim, self.num_classes) self._init_weights() def _init_weights(self): """ Initialize weights in a better way than default. 
See https://github.com/pytorch/pytorch/issues/18182 """ for m in self.modules(): if type(m) in { nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear, }: nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu") if m.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(m.bias, -bound, bound) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the LineCNN to a black-and-white input image. Parameters ---------- x (B, 1, H, W) input image Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and self.window_width C is self.num_classes """ _B, _C, _H, _W = x.shape x = self.convs(x) # (B, FC_DIM, 1, Sx) x = x.squeeze(2).permute(0, 2, 1) # (B, S, FC_DIM) x = F.relu(self.fc1(x)) # -> (B, S, FC_DIM) x = self.dropout(x) x = self.fc2(x) # (B, S, C) x = x.permute(0, 2, 1) # -> (B, C, S) if self.limit_output_length: x = x[:, :, : self.output_length] return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab07/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import torch from 
torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, 
default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab07/text_recognizer/models/line_cnn_transformer.py ================================================ """Model that combines a LineCNN with a Transformer model for text prediction.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .line_cnn import LineCNN from .transformer_util import generate_square_subsequent_mask, PositionalEncoding TF_DIM = 256 TF_FC_DIM = 256 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 class LineCNNTransformer(nn.Module): """Process the line through a CNN and process the resulting sequence with a Transformer decoder.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # Instantiate LineCNN with "num_classes" set to self.dim data_config_for_line_cnn = {**data_config} data_config_for_line_cnn["mapping"] = list(range(self.dim)) self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args) # LineCNN outputs (B, E, S) logits, with E == dim self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.pos_encoder = PositionalEncoding(d_model=self.dim) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (Sx, B, E) logits """ x = self.line_cnn(x) # (B, E, Sx) x = x * math.sqrt(self.dim) x = x.permute(2, 0, 1) # (Sx, B, E) x = self.pos_encoder(x) # (Sx, B, E) return x def decode(self, x, y): """Decode a batch of encoded images x using preceding ground truth y.
Parameters ---------- x (Sx, B, E) image encoded as a sequence y (B, Sy) with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) logits """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output def forward(self, x: torch.Tensor) -> torch.Tensor: """Predict sequences of tokens from input images auto-regressively. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1:] # Set the last output token # Set all tokens after end token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) @staticmethod def add_to_argparse(parser): LineCNN.add_to_argparse(parser) parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser ================================================ FILE: lab07/text_recognizer/models/mlp.py 
================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab07/text_recognizer/models/resnet_transformer.py ================================================ """Model combining a ResNet with a Transformer for image-to-sequence tasks.""" import argparse import math from typing import Any, Dict import torch from torch import nn import torchvision from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage TF_DIM = 256 TF_FC_DIM = 1024 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 RESNET_DIM = 512 # hard-coded class ResnetTransformer(nn.Module): """Pass an image through a Resnet and decode the resulting embedding with a Transformer.""" def __init__( 
self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) self.mapping = data_config["mapping"] inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # ## Encoder part - should output vector sequence of length self.dim per sample resnet = torchvision.models.resnet18(weights=None) self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) # encoder_projection will output (B, dim, _H, _W) logits self.enc_pos_encoder = PositionalEncodingImage( d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] ) # Max (Ho, Wo) # ## Decoder part self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def forward(self, x: torch.Tensor) -> torch.Tensor: """Autoregressively produce sequences of labels from input images.
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- output_tokens (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1] # Set the last output token # Early stopping of prediction loop to speed up prediction if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): break # Set all tokens after end or padding token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") if self.encoder_projection.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.encoder_projection.bias, -bound, bound) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps """ _B, C, _H, _W = x.shape if C == 1: x = x.repeat(1, 3, 1, 1) x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo return x def decode(self, x, y): """Decode a batch of encoded images x with guiding sequences y. During autoregressive inference, the guiding sequence will be previous predictions. During training, the guiding sequence will be the ground truth. Parameters ---------- x (Sx, B, E) images encoded as sequences of embeddings y (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) batch of logit sequences """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.dec_pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output @staticmethod def add_to_argparse(parser): parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser
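# Illustration only, not part of the module: a minimal pure-Python sketch of the
# clean-up pass that forward() runs after autoregressive decoding, where every
# position following an end or padding token is overwritten with padding.
# The token ids START, END, PAD and the helper name pad_after_end are
# hypothetical, chosen for this example; the real ids come from data_config["mapping"].

```python
# Hypothetical token ids, for illustration only.
START, END, PAD = 1, 2, 3


def pad_after_end(tokens, end_token=END, padding_token=PAD):
    """Return a copy of `tokens` in which everything after the first end token is padding."""
    out = list(tokens)
    for i in range(1, len(out)):
        # Same rule as the loop in forward(): if the previous position is an end
        # or padding token, this position becomes padding.
        if out[i - 1] in (end_token, padding_token):
            out[i] = padding_token
    return out


print(pad_after_end([START, 7, 8, END, 9, 10]))  # prints [1, 7, 8, 2, 3, 3]
```

# The model applies the same rule to a whole (B, S) batch at once via boolean
# indexing; the list version above only shows the per-sequence logic.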
================================================ FILE: lab07/text_recognizer/models/transformer_util.py ================================================ """Position Encoding and other utilities for Transformers.""" import math import torch from torch import Tensor import torch.nn as nn class PositionalEncodingImage(nn.Module): """ Module used to add 2-D positional encodings to the feature-map produced by the encoder. Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. """ def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None: super().__init__() self.d_model = d_model assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w) self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor: pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1, d_model // 2) pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) return pe def forward(self, x: Tensor) -> Tensor: """pytorch.nn.module.forward""" # x.shape = (B, d_model, H, W) assert x.shape[1] == self.pe.shape[0] # type: ignore x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore return x class PositionalEncoding(torch.nn.Module): """Classic Attention-is-all-you-need positional encoding.""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None: super().__init__() self.dropout = torch.nn.Dropout(p=dropout) pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model)
self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_len: int) -> torch.Tensor: pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(1) return pe def forward(self, x: torch.Tensor) -> torch.Tensor: # x.shape = (S, B, d_model) assert x.shape[2] == self.pe.shape[2] # type: ignore x = x + self.pe[: x.size(0)] # type: ignore return self.dropout(x) def generate_square_subsequent_mask(size: int) -> torch.Tensor: """Generate a triangular (size, size) mask.""" mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask ================================================ FILE: lab07/text_recognizer/paragraph_text_recognizer.py ================================================ """Detects a paragraph of text in an input image. 
Example usage as a script: python text_recognizer/paragraph_text_recognizer.py \ text_recognizer/tests/support/paragraphs/a01-077.png python text_recognizer/paragraph_text_recognizer.py \ https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png """ import argparse from pathlib import Path from typing import Sequence, Union from PIL import Image import torch from text_recognizer import util from text_recognizer.stems.paragraph import ParagraphStem STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "paragraph-text-recognizer" MODEL_FILE = "model.pt" class ParagraphTextRecognizer: """Recognizes a paragraph of text in an image.""" def __init__(self, model_path=None): if model_path is None: model_path = STAGED_MODEL_DIRNAME / MODEL_FILE self.model = torch.jit.load(model_path) self.mapping = self.model.mapping self.ignore_tokens = self.model.ignore_tokens self.stem = ParagraphStem() @torch.no_grad() def predict(self, image: Union[str, Path, Image.Image]) -> str: """Predict/infer text in input image (which can be a file path or url).""" image_pil = image if not isinstance(image, Image.Image): image_pil = util.read_image_pil(image, grayscale=True) image_tensor = self.stem(image_pil).unsqueeze(axis=0) y_pred = self.model(image_tensor)[0] pred_str = convert_y_label_to_string(y=y_pred, mapping=self.mapping, ignore_tokens=self.ignore_tokens) return pred_str def convert_y_label_to_string(y: torch.Tensor, mapping: Sequence[str], ignore_tokens: Sequence[int]) -> str: return "".join([mapping[i] for i in y if i not in ignore_tokens]) def main(): parser = argparse.ArgumentParser(description=__doc__.split("\n")[0]) parser.add_argument( "filename", type=str, help="Name for an image file. 
This can be a local path, a URL, a URI from AWS/GCP/Azure storage, an HDFS path, or any other resource locator supported by the smart_open library.", ) args = parser.parse_args() text_recognizer = ParagraphTextRecognizer() pred_str = text_recognizer.predict(args.filename) print(pred_str) if __name__ == "__main__": main() ================================================ FILE: lab07/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible.
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab07/text_recognizer/stems/line.py ================================================ import random from PIL import Image from torchvision import transforms import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.image import ImageStem class LineStem(ImageStem): """A stem for handling images containing a line of text.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.5, 1)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "translate": (0, 0.05), "scale": (0.4, 1.1), "shear": (-40, 50), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } if augment: self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] ) class IAMLineStem(ImageStem): """A stem for handling images containing lines of text from the IAMLines dataset.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() def embed_crop(crop, augment=augment): # crop is PIL.image of dtype="L" (so values range from 0 -> 255) image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) # Resize crop crop_width, crop_height = crop.size new_crop_height = metadata.IMAGE_HEIGHT new_crop_width = int(new_crop_height * 
(crop_width / crop_height)) if augment: # Add random stretching new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR) # Embed in the image x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) y = metadata.IMAGE_HEIGHT - new_crop_height image.paste(crop_resized, (x, y)) return image if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.8, 1.6)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 1, "shear": (-30, 20), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } pil_transforms_list = [transforms.Lambda(embed_crop)] if augment: pil_transforms_list += [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] self.pil_transforms = transforms.Compose(pil_transforms_list) ================================================ FILE: lab07/text_recognizer/stems/paragraph.py ================================================ """IAMParagraphs Stem class.""" import torchvision.transforms as transforms import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.image import ImageStem IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH class ParagraphStem(ImageStem): """A stem for handling images that contain a paragraph of text.""" def __init__( self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None, random_perspective_kwargs=None, gaussian_blur_kwargs=None, sharpness_kwargs=None, ): super().__init__() if not augment: self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)]) else: if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "shear": 6, 
"scale": (0.95, 1), "interpolation": transforms.InterpolationMode.BILINEAR, } if random_perspective_kwargs is None: random_perspective_kwargs = { "distortion_scale": 0.2, "p": 0.5, "interpolation": transforms.InterpolationMode.BILINEAR, } if gaussian_blur_kwargs is None: gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if sharpness_kwargs is None: sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} # IMAGE_SHAPE is (576, 640) self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomCrop( size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant" ), transforms.RandomAffine(**random_affine_kwargs), transforms.RandomPerspective(**random_perspective_kwargs), transforms.GaussianBlur(**gaussian_blur_kwargs), transforms.RandomAdjustSharpness(**sharpness_kwargs), ] ) ================================================ FILE: lab07/text_recognizer/tests/test_callback_utils.py ================================================ """Tests for the text_recognizer.callbacks.util module.""" import random import string import tempfile import pytorch_lightning as pl from text_recognizer.callbacks.util import check_and_warn def test_check_and_warn_simple(): """Test the success and failure in the case of a simple class we control.""" class Foo: pass # a class with no special attributes letters = string.ascii_lowercase random_attribute = "".join(random.choices(letters, k=10)) assert check_and_warn(Foo(), random_attribute, "random feature") assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects") def test_check_and_warn_tblogger(): """Test that we return a truthy value when trying to log tables with TensorBoard. We added check_and_warn in order to prevent a crash if this happens. 
""" tblogger = pl.loggers.TensorBoardLogger(save_dir=tempfile.TemporaryDirectory()) assert check_and_warn(tblogger, "log_table", "tables") def test_check_and_warn_wandblogger(): """Test that we return a falsy value when we try to log tables with W&B. In adding check_and_warn, we don't want to block the feature in the happy path. """ wandblogger = pl.loggers.WandbLogger(anonymous=True) assert not check_and_warn(wandblogger, "log_table", "tables") ================================================ FILE: lab07/text_recognizer/tests/test_iam.py ================================================ """Test for data.iam module.""" from text_recognizer.data.iam import IAM def test_iam_parsed_lines(): """Tests that we retrieve the same number of line labels and line image cropregions.""" iam = IAM() iam.prepare_data() for iam_id in iam.all_ids: assert len(iam.line_strings_by_id[iam_id]) == len(iam.line_regions_by_id[iam_id]) def test_iam_data_splits(): """Fails when any identifiers are shared between training, test, or validation.""" iam = IAM() iam.prepare_data() assert not set(iam.train_ids) & set(iam.validation_ids) assert not set(iam.train_ids) & set(iam.test_ids) assert not set(iam.validation_ids) & set(iam.test_ids) ================================================ FILE: lab07/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, 
grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def encode_b64_image(image, format="png"): """Encode a PIL image as a base64 string.""" _buffer = BytesIO() # bytes that live in memory image.save(_buffer, format=format) # but which we write to like a file encoded_image = base64.b64encode(_buffer.getvalue()).decode("utf8") return encoded_image def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab07/training/__init__.py ================================================ ================================================ FILE: lab07/training/cleanup_artifacts.py ================================================ """Removes artifacts from projects and runs. 
Artifacts are binary files that we want to track and version but don't want to include in git, generally because they are too large, because they don't have meaningful diffs, or because they change more quickly than code. During development, we often generate artifacts that we don't really need, e.g. model weights for an overfitting test run. Space on artifact storage is generally very large, but it is limited, so we should occasionally delete unneeded artifacts to reclaim some of that space. For usage help, run python training/cleanup_artifacts.py --help """ import argparse import wandb api = wandb.Api() DEFAULT_PROJECT = "fsdl-text-recognizer-2022-training" DEFAULT_ENTITY = api.default_entity def _setup_parser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--entity", type=str, default=None, help="The entity from which to remove artifacts. Provide the value DEFAULT " + f"to use the default WANDB_ENTITY, which is currently {DEFAULT_ENTITY}.", ) parser.add_argument( "--project", type=str, default=DEFAULT_PROJECT, help=f"The project from which to remove artifacts. Default is {DEFAULT_PROJECT}", ) parser.add_argument( "--run_ids", type=str, default=None, nargs="*", help="One or more run IDs from which to remove artifacts. Default is None.", ) parser.add_argument( "--run_name_res", type=str, default=None, nargs="*", help="One or more regular expressions to use to select runs (by display name) from which to remove artifacts. See wandb.Api.runs documentation for details on the syntax. Beware that this is a footgun and consider using interactively with --dryrun and -v. Default is None.", metavar="RUN_NAME_REGEX", ) flags = parser.add_mutually_exclusive_group() flags.add_argument("--all", action="store_true", help="Delete all artifacts from selected runs.") flags.add_argument( "--no-alias", action="store_true", help="Delete all artifacts without an alias from selected runs." 
) flags.add_argument( "--aliases", type=str, nargs="*", help="Delete artifacts that have any of the aliases from the provided list from selected runs.", ) parser.add_argument( "-v", action="store_true", dest="verbose", help="Display information about targeted entities, projects, runs, and artifacts.", ) parser.add_argument( "--dryrun", action="store_true", help="Select artifacts without deleting them and display which artifacts were selected.", ) return parser def main(args): entity = _get_entity_from(args) project_path = f"{entity}/{args.project}" runs = _get_runs(project_path, args.run_ids, args.run_name_res, verbose=args.verbose) artifact_selector = _get_selector_from(args) protect_aliases = args.no_alias # avoid deletion of any aliased artifacts for run in runs: clean_run_artifacts( run, selector=artifact_selector, protect_aliases=protect_aliases, verbose=args.verbose, dryrun=args.dryrun ) def clean_run_artifacts(run, selector, protect_aliases=True, verbose=False, dryrun=True): artifacts = run.logged_artifacts() for artifact in artifacts: if selector(artifact): remove_artifact(artifact, protect_aliases=protect_aliases, verbose=verbose, dryrun=dryrun) def remove_artifact(artifact, protect_aliases, verbose=False, dryrun=True): project, entity, id = artifact.project, artifact.entity, artifact.id type, aliases = artifact.type, artifact.aliases if verbose or dryrun: print(f"selecting for deletion artifact {project}/{entity}/{id} of type {type} with aliases {aliases}") if not dryrun: artifact.delete(delete_aliases=not protect_aliases) def _get_runs(project_path, run_ids=None, run_name_res=None, verbose=False): if run_ids is None: run_ids = [] if run_name_res is None: run_name_res = [] runs = [] for run_id in run_ids: runs.append(_get_run_by_id(project_path, run_id, verbose=verbose)) for run_name_re in run_name_res: runs += _get_runs_by_name_re(project_path, run_name_re, verbose=verbose) return runs def _get_run_by_id(project_path, run_id, verbose=False): path = 
f"{project_path}/{run_id}" run = api.run(path) if verbose: print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}") return run def _get_runs_by_name_re(project_path, run_name_re, verbose=False): matching_runs = api.runs(path=project_path, filters={"display_name": {"$regex": run_name_re}}) if verbose: for run in matching_runs: print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}") return matching_runs def _get_selector_from(args, verbose=False): if args.all: if verbose: print("removing all artifacts from matching runs") return lambda _: True if args.no_alias: if verbose: print("removing all artifacts with no aliases from matching runs") return lambda artifact: artifact.aliases == [] if args.aliases: if verbose: print(f"removing all artifacts with any of {args.aliases} in aliases from matching runs") return lambda artifact: any(alias in artifact.aliases for alias in args.aliases) if verbose: print("removing no artifacts matching runs") return lambda _: False def _get_entity_from(args, verbose=False): entity = args.entity if entity is None: raise RuntimeError(f"No entity argument provided. 
Use --entity=DEFAULT to use {DEFAULT_ENTITY}.") elif entity == "DEFAULT": entity = DEFAULT_ENTITY if verbose: print(f"using default entity {entity}") else: if verbose: print(f"using entity {entity}") return entity if __name__ == "__main__": parser = _setup_parser() args = parser.parse_args() main(args) ================================================ FILE: lab07/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import callbacks as cb from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--wandb", action="store_true", default=False, help="If passed, logs experiment results to Weights & Biases. 
Otherwise logs only to local Tensorboard.", ) parser.add_argument( "--profile", action="store_true", default=False, help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.", ) parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." + " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. 
Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help ``` """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.loss == "transformer": lit_model_class = lit_models.TransformerLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" if goldstar_metric == "validation/cer": filename_format += "-validation.cer={validation/cer:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.wandb: logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train") logger.watch(model, log_freq=max(100, args.log_every_n_steps)) logger.log_hyperparams(vars(args)) experiment_dir = logger.experiment.dir callbacks += 
[cb.ModelSizeLogger(), cb.LearningRateMonitor()] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) if args.wandb and args.loss in ("transformer",): callbacks.append(cb.ImageToTextLogger()) trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) if args.profile: sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0) profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir) profiler.STEP_FUNCTIONS = {"training_step"} # only profile training else: profiler = pl.profiler.PassThroughProfiler() trainer.profiler = profiler trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) trainer.profiler = pl.profiler.PassThroughProfiler() # turn profiling off during testing best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") if args.wandb: rank_zero_info("Best model also uploaded to W&B ") trainer.test(datamodule=data, ckpt_path=best_model_path) else: trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab07/training/stage_model.py ================================================ """Stages a model for use in production. If based on a checkpoint, the model is converted to torchscript, saved locally, and uploaded to W&B. If based on a model that is already converted and uploaded, the model file is downloaded locally. For details on how the W&B artifacts backing the checkpoints and models are handled, see the documentation for stage_model.find_artifact. 
""" import argparse from pathlib import Path import tempfile import torch import wandb from text_recognizer.lit_models import TransformerLitModel from training.util import setup_data_and_model_from_args # these names are all set by the pl.loggers.WandbLogger MODEL_CHECKPOINT_TYPE = "model" BEST_CHECKPOINT_ALIAS = "best" MODEL_CHECKPOINT_PATH = "model.ckpt" LOG_DIR = Path("training") / "logs" STAGED_MODEL_TYPE = "prod-ready" # we can choose the name of this type, and ideally it's different from checkpoints STAGED_MODEL_FILENAME = "model.pt" # standard nomenclature; pytorch_model.bin is also used PROJECT_ROOT = Path(__file__).resolve().parents[1] LITMODEL_CLASS = TransformerLitModel api = wandb.Api() DEFAULT_ENTITY = api.default_entity DEFAULT_FROM_PROJECT = "fsdl-text-recognizer-2022-training" DEFAULT_TO_PROJECT = "fsdl-text-recognizer-2022-training" DEFAULT_STAGED_MODEL_NAME = "paragraph-text-recognizer" PROD_STAGING_ROOT = PROJECT_ROOT / "text_recognizer" / "artifacts" def main(args): prod_staging_directory = PROD_STAGING_ROOT / args.staged_model_name prod_staging_directory.mkdir(exist_ok=True, parents=True) entity = _get_entity_from(args) # if we're just fetching an already compiled model if args.fetch: # find it and download it staged_model = f"{entity}/{args.from_project}/{args.staged_model_name}:latest" artifact = download_artifact(staged_model, prod_staging_directory) print_info(artifact) return # and we're done # otherwise, we'll need to download the weights, compile the model, and save it with wandb.init( job_type="stage", project=args.to_project, dir=LOG_DIR ): # log staging to W&B so prod and training are connected # find the model checkpoint and retrieve its artifact name and an api handle ckpt_at, ckpt_api = find_artifact( entity, args.from_project, type=MODEL_CHECKPOINT_TYPE, alias=args.ckpt_alias, run=args.run ) # get the run that produced that checkpoint logging_run = get_logging_run(ckpt_api) print_info(ckpt_api, logging_run) metadata = 
get_checkpoint_metadata(logging_run, ckpt_api) # create an artifact for the staged, deployable model staged_at = wandb.Artifact(args.staged_model_name, type=STAGED_MODEL_TYPE, metadata=metadata) with tempfile.TemporaryDirectory() as tmp_dir: # download the checkpoint to a temporary directory download_artifact(ckpt_at, tmp_dir) # reload the model from that checkpoint model = load_model_from_checkpoint(metadata, directory=tmp_dir) # save the model to torchscript in the staging directory save_model_to_torchscript(model, directory=prod_staging_directory) # upload the staged model so it can be downloaded elsewhere upload_staged_model(staged_at, from_directory=prod_staging_directory) def find_artifact(entity: str, project: str, type: str, alias: str, run=None): """Finds the artifact of a given type with a given alias under the entity and project. Parameters ---------- entity The name of the W&B entity under which the artifact is logged. project The name of the W&B project under which the artifact is logged. type The name of the type of the artifact. alias : str The alias for this artifact. This alias must be unique within the provided type for the run, if provided, or for the project, if the run is not provided. run : str Optionally, the run in which the artifact is located. Returns ------- Tuple[path, artifact] An identifying path and an API handle for a matching artifact. 
""" if run is not None: path = _find_artifact_run(entity, project, type=type, run=run, alias=alias) else: path = _find_artifact_project(entity, project, type=type, alias=alias) return path, api.artifact(path) def get_logging_run(artifact): api_run = artifact.logged_by() return api_run def print_info(artifact, run=None): if run is None: run = get_logging_run(artifact) full_artifact_name = f"{artifact.entity}/{artifact.project}/{artifact.name}" print(f"Using artifact {full_artifact_name}") artifact_url_prefix = f"https://wandb.ai/{artifact.entity}/{artifact.project}/artifacts/{artifact.type}" artifact_url_suffix = f"{artifact.name.replace(':', '/')}" print(f"View at URL: {artifact_url_prefix}/{artifact_url_suffix}") print(f"Logged by {run.name} -- {run.project}/{run.entity}/{run.id}") print(f"View at URL: {run.url}") def get_checkpoint_metadata(run, checkpoint): config = run.config out = {"config": config} try: ckpt_filename = checkpoint.metadata["original_filename"] out["original_filename"] = ckpt_filename metric_key = checkpoint.metadata["ModelCheckpoint"]["monitor"] metric_score = checkpoint.metadata["score"] out[metric_key] = metric_score except KeyError: pass return out def download_artifact(artifact_path, target_directory): """Downloads the artifact at artifact_path to the target directory.""" if wandb.run is not None: # if we are inside a W&B run, track that we used this artifact artifact = wandb.use_artifact(artifact_path) else: # otherwise, just download the artifact via the API artifact = api.artifact(artifact_path) artifact.download(root=target_directory) return artifact def load_model_from_checkpoint(ckpt_metadata, directory): config = ckpt_metadata["config"] args = argparse.Namespace(**config) _, model = setup_data_and_model_from_args(args) # load LightningModule from checkpoint pth = Path(directory) / MODEL_CHECKPOINT_PATH lit_model = LITMODEL_CLASS.load_from_checkpoint(checkpoint_path=pth, args=args, model=model, strict=False) lit_model.eval() return 
lit_model def save_model_to_torchscript(model, directory): scripted_model = model.to_torchscript(method="script", file_path=None) path = Path(directory) / STAGED_MODEL_FILENAME torch.jit.save(scripted_model, path) def upload_staged_model(staged_at, from_directory): staged_at.add_file(Path(from_directory) / STAGED_MODEL_FILENAME) wandb.log_artifact(staged_at) def _find_artifact_run(entity, project, type, run, alias): run_name = f"{entity}/{project}/{run}" api_run = api.run(run_name) artifacts = api_run.logged_artifacts() match = [art for art in artifacts if alias in art.aliases and art.type == type] if not match: raise ValueError(f"No artifact with alias {alias} found at {run_name} of type {type}") if len(match) > 1: raise ValueError(f"Multiple artifacts ({len(match)}) with alias {alias} found at {run_name} of type {type}") return f"{entity}/{project}/{match[0].name}" def _find_artifact_project(entity, project, type, alias): project_name = f"{entity}/{project}" api_project = api.project(project, entity=entity) api_artifact_types = api_project.artifacts_types() # loop through all artifact types in this project for artifact_type in api_artifact_types: if artifact_type.name != type: continue # skipping those that don't match type collections = artifact_type.collections() # loop through all artifacts and their versions for collection in collections: versions = collection.versions() for version in versions: if alias in version.aliases: # looking for the first one that matches the alias return f"{project_name}/{version.name}" raise ValueError(f"Artifact with alias {alias} not found in type {type} in {project_name}") raise ValueError(f"Artifact type {type} not found. {project_name} could be private or not exist.") def _get_entity_from(args): entity = args.entity if entity is None: raise RuntimeError(f"No entity argument provided. 
Use --entity=DEFAULT to use {DEFAULT_ENTITY}.") elif entity == "DEFAULT": entity = DEFAULT_ENTITY return entity def _setup_parser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--fetch", action="store_true", help=f"If provided, check ENTITY/FROM_PROJECT for an artifact with the provided STAGED_MODEL_NAME and download its latest version to {PROD_STAGING_ROOT}/STAGED_MODEL_NAME.", ) parser.add_argument( "--entity", type=str, default=None, help=f"Entity from which to download the checkpoint. Note that checkpoints are always uploaded to the logged-in wandb entity. Pass the value 'DEFAULT' to also download from the default entity, which is currently {DEFAULT_ENTITY}.", ) parser.add_argument( "--from_project", type=str, default=DEFAULT_FROM_PROJECT, help=f"Project from which to download the checkpoint. Default is {DEFAULT_FROM_PROJECT}.", ) parser.add_argument( "--to_project", type=str, default=DEFAULT_TO_PROJECT, help=f"Project to which to upload the compiled model. Default is {DEFAULT_TO_PROJECT}.", ) parser.add_argument( "--run", type=str, default=None, help=f"Optionally, the name of a run to check for an artifact of type {MODEL_CHECKPOINT_TYPE} that has the provided CKPT_ALIAS. Default is None.", ) parser.add_argument( "--ckpt_alias", type=str, default=BEST_CHECKPOINT_ALIAS, help=f"Alias that identifies which model checkpoint should be staged. The artifact's alias can be set manually or programmatically elsewhere. Default is {BEST_CHECKPOINT_ALIAS!r}.", ) parser.add_argument( "--staged_model_name", type=str, default=DEFAULT_STAGED_MODEL_NAME, help=f"Name to give the staged model artifact. 
Default is {DEFAULT_STAGED_MODEL_NAME!r}.", ) return parser if __name__ == "__main__": parser = _setup_parser() args = parser.parse_args() main(args) ================================================ FILE: lab07/training/tests/test_memorize_iam.sh ================================================ #!/bin/bash set -uo pipefail set +e # tests whether we can achieve a criterion loss # on a single batch within a certain number of epochs FAILURE=false # constants and CLI args set by aiming for <5 min test on commodity GPU, # including data download step MAX_EPOCHS="${1:-100}" # syntax for basic optional arguments in bash CRITERION="${2:-1.0}" # train on GPU if it's available GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))') python ./training/run_experiment.py \ --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \ --augment_data false --tf_dropout 0.0 \ --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \ --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true if [ "$FAILURE" = true ]; then echo "Memorization test failed at loss criterion $CRITERION" exit 1 fi echo "Memorization test passed at loss criterion $CRITERION" exit 0 ================================================ FILE: lab07/training/tests/test_model_development.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false CI="${CI:-false}" if [ "$CI" = false ]; then export WANDB_PROJECT="fsdl-testing-2022" else export WANDB_PROJECT="fsdl-testing-2022-ci" fi echo "training smaller version of real model class on real data" python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --tf_dim 4 --tf_fc_dim 
2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \ --limit_train_batches 1 --limit_val_batches 1 --limit_test_batches 1 --num_sanity_val_steps 0 \ --num_workers 1 --wandb || FAILURE=true TRAIN_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//") echo "staging trained model from run $TRAIN_RUN" python training/stage_model.py --entity DEFAULT --run "$TRAIN_RUN" --staged_model_name test-dummy --ckpt_alias latest --to_project "$WANDB_PROJECT" --from_project "$WANDB_PROJECT" || FAILURE=true echo "fetching staged model" python training/stage_model.py --entity DEFAULT --fetch --from_project "$WANDB_PROJECT" --staged_model_name test-dummy || FAILURE=true STAGE_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//") if [ "$FAILURE" = true ]; then echo "Model development test failed" echo "cleaning up local files" rm -rf text_recognizer/artifacts/test-dummy echo "leaving remote files in place" exit 1 fi echo "cleaning up local and remote files" rm -rf text_recognizer/artifacts/test-dummy python training/cleanup_artifacts.py --entity DEFAULT --project "$WANDB_PROJECT" \ --run_ids "$TRAIN_RUN" "$STAGE_RUN" --all -v # note: if $TRAIN_RUN and $STAGE_RUN are not set, this will fail. # that's good because it prevents --all from deleting every artifact in the project. 
echo "Model development test passed" exit 0 ================================================ FILE: lab07/training/tests/test_run_experiment.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false echo "running full loop test with CNN on fake data" python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1 || FAILURE=true echo "running fast_dev_run test of real model class on real data" python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \ --fast_dev_run --num_sanity_val_steps 0 \ --num_workers 1 || FAILURE=true if [ "$FAILURE" = true ]; then echo "Test for run_experiment.py failed" exit 1 fi echo "Tests for run_experiment.py passed" exit 0 ================================================ FILE: lab07/training/util.py ================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 
'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: lab08/.flake8 ================================================ [flake8] select = ANN,B,B9,BLK,C,D,E,F,I,S,W # only check selected error codes max-complexity = 12 # C9 - flake8 McCabe Complexity checker -- threshold max-line-length = 120 # E501 - flake8 -- line length too long, actually handled by black extend-ignore = # E W - flake8 PEP style check E203,E402,E501,W503, # whitespace, import, line length, binary operator line breaks # S - flake8-bandit safety check S101,S113,S311,S105, # assert removed in bytecode, no request timeout, pRNG not secure, hardcoded password # ANN - flake8-annotations type annotation check ANN,ANN002,ANN003,ANN101,ANN102,ANN202, # ignore all for now, but always ignore some # D1 - flake8-docstrings docstring style check D100,D102,D103,D104,D105, # missing docstrings # D2 D4 - flake8-docstrings docstring style check D200,D205,D400,D401, # whitespace issues and first line content # DAR - flake8-darglint docstring correctness check DAR103, # mismatched or missing type in docstring application-import-names = app_gradio,text_recognizer,tests,training # flake8-import-order: which names are first party? import-order-style = google # flake8-import-order: which import order style guide do we use? docstring-convention = numpy # flake8-docstrings: which docstring style guide do we use? strictness = short # darglint: how "strict" are we with docstring completeness? 
docstring-style = numpy # darglint: which docstring style guide do we use? suppress-none-returning = true # flake8-annotations: do we allow un-annotated Nones in returns? mypy-init-return = true # flake8-annotations: do we allow init to have no return annotation? per-file-ignores = # list of case-by-case ignores, see files for details */__init__.py:F401,I */data/*.py:DAR data/*.py:F,I *text_recognizer/util.py:DAR101,F401 *training/run_experiment.py:I202 *app_gradio/app.py:I202 ================================================ FILE: lab08/.github/workflows/pre-commit.yml ================================================ name: pre-commit on: pull_request: push: # allows this Action to be triggered manually workflow_dispatch: jobs: pre-commit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 with: python-version: '3.10' - uses: pre-commit/action@v3.0.0 ================================================ FILE: lab08/.pre-commit-config.yaml ================================================ repos: # a set of useful Python-based pre-commit hooks - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: # list of definitions and supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace # removes any whitespace at the ends of lines - id: check-toml # check toml syntax by loading all toml files - id: check-yaml # check yaml syntax by loading all yaml files - id: check-json # check-json syntax by loading all json files - id: check-merge-conflict # check for files with merge conflict strings args: ['--assume-in-merge'] # and run this check even when not explicitly in a merge - id: check-added-large-files # check that no "large" files have been added args: ['--maxkb=10240'] # where large means 10MB+, as in Hugging Face's git server - id: debug-statements # check for python debug statements (import pdb, breakpoint, etc.) - id: detect-private-key # checks for private keys (BEGIN X PRIVATE KEY, etc.) 
# black python autoformatting - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black # additional configuration of black in pyproject.toml # flake8 python linter with all the fixins - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 hooks: - id: flake8 exclude: (lab01|lab02|lab03|lab04|lab06|lab07|lab08) additional_dependencies: [ flake8-bandit, flake8-bugbear, flake8-docstrings, flake8-import-order, darglint, mypy, pycodestyle, pydocstyle] args: ["--config", ".flake8"] # additional configuration of flake8 and extensions in .flake8 # shellcheck-py for linting shell files - repo: https://github.com/shellcheck-py/shellcheck-py rev: v0.8.0.4 hooks: - id: shellcheck ================================================ FILE: lab08/api_serverless/Dockerfile ================================================ # Starting from an official AWS image # Keep any dependencies and versions in this file aligned with the environment.yml and Makefile FROM public.ecr.aws/lambda/python:3.10 # Install Python dependencies COPY requirements/prod.txt ./requirements.txt RUN pip install --upgrade pip==23.1.2 RUN pip install -r requirements.txt # Copy only the relevant directories and files # note that we use a .dockerignore file to avoid copying logs etc. 
COPY text_recognizer/ ./text_recognizer COPY api_serverless/api.py ./api.py CMD ["api.handler"] ================================================ FILE: lab08/api_serverless/__init__.py ================================================ """Cloud function-backed API for paragraph recognition.""" ================================================ FILE: lab08/api_serverless/api.py ================================================ """AWS Lambda function serving text_recognizer predictions.""" import json from PIL import ImageStat from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer import text_recognizer.util as util model = ParagraphTextRecognizer() def handler(event, _context): """Provide main prediction API.""" print("INFO loading image") image = _load_image(event) if image is None: return {"statusCode": 400, "message": "neither image_url nor image found in event"} print("INFO image loaded") print("INFO starting inference") pred = model.predict(image) print("INFO inference complete") image_stat = ImageStat.Stat(image) print("METRIC image_mean_intensity {}".format(image_stat.mean[0])) print("METRIC image_area {}".format(image.size[0] * image.size[1])) print("METRIC pred_length {}".format(len(pred))) print("INFO pred {}".format(pred)) return {"pred": str(pred)} def _load_image(event): event = _from_string(event) event = _from_string(event.get("body", event)) image_url = event.get("image_url") if image_url is not None: print("INFO url {}".format(image_url)) return util.read_image_pil(image_url, grayscale=True) else: image = event.get("image") if image is not None: print("INFO reading image from event") return util.read_b64_image(image, grayscale=True) else: return None def _from_string(event): if isinstance(event, str): return json.loads(event) else: return event ================================================ FILE: lab08/app_gradio/Dockerfile ================================================ # The "buster" flavor of the official docker Python image 
is based on Debian and includes common packages. # Keep any dependencies and versions in this file aligned with the environment.yml and Makefile FROM python:3.10-buster # Create the working directory # set -x prints commands and set -e causes us to stop on errors RUN set -ex && mkdir /repo WORKDIR /repo # Install Python dependencies COPY requirements/prod.txt ./requirements.txt RUN pip install --upgrade pip==23.1.2 RUN pip install -r requirements.txt ENV PYTHONPATH ".:" # Copy only the relevant directories # note that we use a .dockerignore file to avoid copying logs etc. COPY text_recognizer/ ./text_recognizer COPY app_gradio/ ./app_gradio # Use docker run -it --rm -p $PORT:11717 to run the web server and listen on host $PORT # add --help to see help for the Python script ENTRYPOINT ["python3", "app_gradio/app.py", "--port", "11717"] ================================================ FILE: lab08/app_gradio/README.md ================================================ ## Full-Paragraph Optical Character Recognition For more on how this application works, [check out the GitHub repo](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022). ### Flagging If the model outputs in the top-right are wrong in some way, let us know by clicking the "flagging" buttons underneath. We'll analyze the results with [Gantry](https://gantry.io/blog/introducing-gantry/) and use them to improve the model! 
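When the app runs with flagging enabled but without Gantry credentials, `app_gradio/app.py` falls back to logging flagged examples to a local CSV file in the `flagged` directory. Below is a minimal sketch for inspecting those rows. It assumes that Gradio's CSV logger writes a file named `log.csv` in that directory; the file name and the helper `load_flagged_rows` are assumptions for illustration and are not part of this repo.

```python
import csv
from pathlib import Path


def load_flagged_rows(log_path="flagged/log.csv"):
    """Return flagged examples as a list of dicts, one per CSV row.

    The default path is an assumption about where the local CSV logger
    writes; pass the actual location if it differs.
    """
    path = Path(log_path)
    if not path.exists():  # nothing has been flagged yet
        return []
    with path.open(newline="", encoding="utf-8") as f:
        # column names come from the CSV header, whatever they happen to be
        return list(csv.DictReader(f))


if __name__ == "__main__":
    rows = load_flagged_rows()
    print(f"{len(rows)} flagged example(s) found")
```

The helper makes no assumptions about column names, so it should keep working if the input or output widgets are relabeled.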
================================================ FILE: lab08/app_gradio/__init__.py ================================================ ================================================ FILE: lab08/app_gradio/app.py ================================================ """Provide an image of handwritten text and get back out a string!""" import argparse import json import logging import os from pathlib import Path from typing import Callable import warnings import gradio as gr from PIL import ImageStat from PIL.Image import Image import requests from app_gradio.flagging import GantryImageToTextLogger, get_api_key from app_gradio.s3_util import make_unique_bucket_name from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer import text_recognizer.util as util os.environ["CUDA_VISIBLE_DEVICES"] = "" # do not use GPU logging.basicConfig(level=logging.INFO) DEFAULT_APPLICATION_NAME = "fsdl-text-recognizer" APP_DIR = Path(__file__).resolve().parent # what is the directory for this application? FAVICON = APP_DIR / "1f95e.png" # path to a small image for display in browser tab and social media README = APP_DIR / "README.md" # path to an app readme file in HTML/markdown DEFAULT_PORT = 11700 def main(args): predictor = PredictorBackend(url=args.model_url) frontend = make_frontend(predictor.run, flagging=args.flagging, gantry=args.gantry, app_name=args.application) frontend.launch( server_name="0.0.0.0", # make server accessible, binding all interfaces # noqa: S104 server_port=args.port, # set a port to bind to, failing if unavailable share=True, # should we create a (temporary) public link on https://gradio.app? favicon_path=FAVICON, # what icon should we display in the address bar? 
) def make_frontend( fn: Callable[[Image], str], flagging: bool = False, gantry: bool = False, app_name: str = "fsdl-text-recognizer" ): """Creates a gradio.Interface frontend for an image to text function.""" examples_dir = Path("text_recognizer") / "tests" / "support" / "paragraphs" example_fnames = [elem for elem in os.listdir(examples_dir) if elem.endswith(".png")] example_paths = [examples_dir / fname for fname in example_fnames] examples = [[str(path)] for path in example_paths] allow_flagging = "never" if flagging: allow_flagging = "manual" api_key = get_api_key() if gantry and api_key: # if we're logging user feedback to Gantry and we have an API key allow_flagging = "manual" # turn on Gradio flagging features # callback for logging input images, output text, and feedback to Gantry flagging_callback = GantryImageToTextLogger(application=app_name, api_key=api_key) # that sends images to S3 flagging_dir = make_unique_bucket_name(prefix=app_name, seed=api_key) else: # otherwise, log to a local CSV file if gantry and api_key is None: warnings.warn("No Gantry API key found, logging to local directory instead.", stacklevel=1) flagging_callback = gr.CSVLogger() flagging_dir = "flagged" else: flagging_callback, flagging_dir = None, None readme = _load_readme(with_logging=allow_flagging == "manual") # build a basic browser interface to a Python function frontend = gr.Interface( fn=fn, # which Python function are we interacting with? outputs=gr.components.Textbox(), # what output widgets does it need? the default text widget # what input widgets does it need? we configure an image widget inputs=gr.components.Image(type="pil", label="Handwritten Text"), title="📝 Text Recognizer", # what should we display at the top of the page? thumbnail=FAVICON, # what should we display when the link is shared, e.g. on social media? description=__doc__, # what should we display just above the interface? article=readme, # what long-form content should we display below the interface? 
examples=examples, # which potential inputs should we provide? cache_examples=False, # should we cache those inputs for faster inference? slows down start allow_flagging=allow_flagging, # should we show users the option to "flag" outputs? flagging_options=["incorrect", "offensive", "other"], # what options do users have for feedback? flagging_callback=flagging_callback, flagging_dir=flagging_dir, ) return frontend class PredictorBackend: """Interface to a backend that serves predictions. To communicate with a backend accessible via a URL, provide the url kwarg. Otherwise, runs a predictor locally. """ def __init__(self, url=None): if url is not None: self.url = url self._predict = self._predict_from_endpoint else: model = ParagraphTextRecognizer() self._predict = model.predict def run(self, image): pred, metrics = self._predict_with_metrics(image) self._log_inference(pred, metrics) return pred def _predict_with_metrics(self, image): pred = self._predict(image) stats = ImageStat.Stat(image) metrics = { "image_mean_intensity": stats.mean, "image_median": stats.median, "image_extrema": stats.extrema, "image_area": image.size[0] * image.size[1], "pred_length": len(pred), } return pred, metrics def _predict_from_endpoint(self, image): """Send an image to an endpoint that accepts JSON and return the predicted text. The endpoint should expect a base64 representation of the image, encoded as a string, under the key "image". It should return the predicted text under the key "pred". Parameters ---------- image A PIL image of handwritten text to be converted into a string. Returns ------- pred A string containing the predictor's guess of the text in the image. 
""" encoded_image = util.encode_b64_image(image) headers = {"Content-type": "application/json"} payload = json.dumps({"image": "data:image/png;base64," + encoded_image}) response = requests.post(self.url, data=payload, headers=headers) pred = response.json()["pred"] return pred def _log_inference(self, pred, metrics): for key, value in metrics.items(): logging.info(f"METRIC {key} {value}") logging.info(f"PRED >begin\n{pred}\nPRED >end") def _load_readme(with_logging=False): with open(README) as f: lines = f.readlines() if not with_logging: lines = lines[: lines.index("\n")] readme = "".join(lines) return readme def _make_parser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--model_url", default=None, type=str, help="Identifies a URL to which to send image data. Data is base64-encoded, converted to a utf-8 string, and then set via a POST request as JSON with the key 'image'. Default is None, which instead sends the data to a model running locally.", ) parser.add_argument( "--port", default=DEFAULT_PORT, type=int, help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.", ) parser.add_argument( "--flagging", action="store_true", help="Pass this flag to allow users to 'flag' model behavior and provide feedback.", ) parser.add_argument( "--gantry", action="store_true", help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.", ) parser.add_argument( "--application", default=DEFAULT_APPLICATION_NAME, type=str, help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. 
Default is {DEFAULT_APPLICATION_NAME}.", ) return parser if __name__ == "__main__": parser = _make_parser() args = parser.parse_args() main(args) ================================================ FILE: lab08/app_gradio/flagging.py ================================================ import os from typing import List, Optional, Union import gantry import gradio as gr from gradio.components import Component from smart_open import open from app_gradio import s3_util from text_recognizer.util import read_b64_string class GantryImageToTextLogger(gr.FlaggingCallback): """A FlaggingCallback that logs flagged image-to-text data to Gantry via S3.""" def __init__(self, application: str, version: Union[int, str, None] = None, api_key: Optional[str] = None): """Logs image-to-text data that was flagged in Gradio to Gantry. Images are logged to Amazon Web Services' Simple Storage Service (S3). The flagging_dir provided to the Gradio interface is used to set the name of the bucket on S3 into which images are logged. See the following tutorial by Dan Bader for a quick overview of S3 and the AWS SDK for Python, boto3: https://realpython.com/python-boto3-aws-s3/ See https://gradio.app/docs/#flagging for details on how flagging data is handled by Gradio. See https://docs.gantry.io for information about logging data to Gantry. Parameters ---------- application The name of the application on Gantry to which flagged data should be uploaded. Gantry validates and monitors data per application. version The schema version to use during validation by Gantry. If not provided, Gantry will use the latest version. A new version will be created if the provided version does not exist yet. api_key Optionally, provide your Gantry API key here. Provided for convenience when testing and developing locally or in notebooks. The API key can alternatively be provided via the GANTRY_API_KEY environment variable. 
""" self.application = application self.version = version gantry.init(api_key=api_key) def setup(self, components: List[Component], flagging_dir: str): """Sets up the GantryImageToTextLogger by creating or attaching to an S3 Bucket.""" self._counter = 0 self.bucket = s3_util.get_or_create_bucket(flagging_dir) s3_util.enable_bucket_versioning(self.bucket) s3_util.add_access_policy(self.bucket) self.image_component_idx, self.text_component_idx = self._find_image_and_text_components(components) def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int: """Sends flagged outputs and feedback to Gantry and image inputs to S3.""" image = flag_data[self.image_component_idx] text = flag_data[self.text_component_idx] feedback = {"flag": flag_option} if username is not None: feedback["user"] = username data_type, image_buffer = read_b64_string(image, return_data_type=True) image_url = self._to_s3(image_buffer.read(), filetype=data_type) self._to_gantry(image_url, text, feedback) self._counter += 1 return self._counter def _to_gantry(self, input_image_url, output_text, feedback): inputs = {"image": input_image_url} outputs = {"output_text": output_text} gantry.log_record(self.application, self.version, inputs=inputs, outputs=outputs, feedback=feedback) def _to_s3(self, image_bytes, key=None, filetype=None): if key is None: key = s3_util.make_key(image_bytes, filetype=filetype) s3_uri = s3_util.get_uri_of(self.bucket, key) with open(s3_uri, "wb") as s3_object: s3_object.write(image_bytes) return s3_uri def _find_image_and_text_components(self, components: List[Component]): image_component_idx, text_component_idx = None, None for idx, component in enumerate(components): if isinstance(component, (gr.inputs.Image, gr.components.Image)): image_component_idx = idx elif isinstance(component, (gr.templates.Text, gr.components.Textbox)): text_component_idx = idx if image_component_idx is None: raise RuntimeError(f"No image input found in gradio interface with 
components {components}") elif text_component_idx is None: raise RuntimeError(f"No text output found in gradio interface with components {components}") return image_component_idx, text_component_idx def get_api_key() -> Optional[str]: """Convenience method for fetching the Gantry API key.""" api_key = os.environ.get("GANTRY_API_KEY") return api_key ================================================ FILE: lab08/app_gradio/s3_util.py ================================================ import hashlib import json import boto3 import botocore S3_URL_FORMAT = "https://{bucket}.s3.{region}.amazonaws.com/{key}" S3_URI_FORMAT = "s3://{bucket}/{key}" s3 = boto3.resource("s3") def get_or_create_bucket(name): """Gets an S3 bucket with boto3 or creates it if it doesn't exist.""" try: # try to create a bucket name, response = _create_bucket(name) except botocore.exceptions.ClientError as err: # error handling from https://github.com/boto/boto3/issues/1195#issuecomment-495842252 status = err.response["ResponseMetadata"]["HTTPStatusCode"] # status codes identify particular errors if status == 409: # if the bucket exists already, pass # we don't need to make it -- we presume we have the right permissions else: raise err bucket = s3.Bucket(name) return bucket def _create_bucket(name): """Creates a bucket with the provided name.""" session = boto3.session.Session() # sessions hold on to credentials and config current_region = session.region_name # so we can pull the default region bucket_config = {"LocationConstraint": current_region} # and apply it to the bucket bucket_response = s3.create_bucket(Bucket=name, CreateBucketConfiguration=bucket_config) return name, bucket_response def make_key(fileobj, filetype=None): """Creates a unique key for the fileobj and optionally append the filetype.""" identifier = make_identifier(fileobj) if filetype is None: return identifier else: return identifier + "." 
+ filetype def make_unique_bucket_name(prefix, seed): """Creates a unique bucket name from a prefix and a seed.""" name = hashlib.sha256(seed.encode("utf-8")).hexdigest()[:10] return prefix + "-" + name def get_url_of(bucket, key=None): """Returns the url of a bucket and optionally of an object in that bucket.""" if not isinstance(bucket, str): bucket = bucket.name region = _get_region(bucket) key = key or "" url = _format_url(bucket, region, key) return url def get_uri_of(bucket, key=None): """Returns the s3:// uri of a bucket and optionally of an object in that bucket.""" if not isinstance(bucket, str): bucket = bucket.name key = key or "" uri = _format_uri(bucket, key) return uri def enable_bucket_versioning(bucket): """Turns on versioning for bucket contents, which avoids deletion.""" if not isinstance(bucket, str): bucket = bucket.name bucket_versioning = s3.BucketVersioning(bucket) return bucket_versioning.enable() def add_access_policy(bucket): """Adds a policy to our bucket that allows the Gantry app to access data.""" access_policy = json.dumps(_get_policy(bucket.name)) s3.meta.client.put_bucket_policy(Bucket=bucket.name, Policy=access_policy) def _get_policy(bucket_name): """Returns a bucket policy allowing Gantry app access as a JSON-compatible dictionary.""" return { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Principal": { "AWS": [ "arn:aws:iam::848836713690:root", "arn:aws:iam::339325199688:root", "arn:aws:iam::665957668247:root", ] }, "Action": ["s3:GetObject", "s3:GetObjectVersion"], "Resource": f"arn:aws:s3:::{bucket_name}/*", }, { "Effect": "Allow", "Principal": { "AWS": [ "arn:aws:iam::848836713690:root", "arn:aws:iam::339325199688:root", "arn:aws:iam::665957668247:root", ] }, "Action": "s3:ListBucketVersions", "Resource": f"arn:aws:s3:::{bucket_name}", }, ], } def make_identifier(byte_data): """Create a unique identifier for a collection of bytes via hashing.""" # feed them to hashing algo -- security is not critical here, so 
we use SHA-1 hashed_data = hashlib.sha1(byte_data) # noqa: S3 identifier = hashed_data.hexdigest() # turn it into hexadecimal return identifier def _get_region(bucket): """Determine the region of an s3 bucket.""" if not isinstance(bucket, str): bucket = bucket.name s3_client = boto3.client("s3") bucket_location_response = s3_client.get_bucket_location(Bucket=bucket) bucket_location = bucket_location_response["LocationConstraint"] return bucket_location def _format_url(bucket_name, region, key=None): key = key or "" url = S3_URL_FORMAT.format(bucket=bucket_name, region=region, key=key) return url def _format_uri(bucket_name, key=None): key = key or "" uri = S3_URI_FORMAT.format(bucket=bucket_name, key=key) return uri ================================================ FILE: lab08/app_gradio/tests/test_app.py ================================================ import json import os import requests from app_gradio import app from text_recognizer import util os.environ["CUDA_VISIBLE_DEVICES"] = "" TEST_IMAGE = "text_recognizer/tests/support/paragraphs/a01-077.png" def test_local_run(): """A quick test to make sure we can build the app and ping the API locally.""" backend = app.PredictorBackend() frontend = app.make_frontend(fn=backend.run) # run the UI without blocking frontend.launch(share=False, prevent_thread_lock=True) local_url = frontend.local_url get_response = requests.get(local_url) assert get_response.status_code == 200, get_response.content image_b64 = util.encode_b64_image(util.read_image_pil(TEST_IMAGE)) local_api = f"{local_url}api/predict" headers = {"Content-Type": "application/json"} payload = json.dumps({"data": ["data:image/png;base64," + image_b64]}) post_response = requests.post(local_api, data=payload, headers=headers) assert post_response.status_code == 200, post_response.content ================================================ FILE: lab08/notebooks/lab01_pytorch.ipynb ================================================ { "cells": [ { "cell_type": 
"markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 01: Deep Neural Networks in PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How to write a basic neural network from scratch in PyTorch\n", "- How the submodules of `torch`, like `torch.nn` and `torch.utils.data`, make writing performant neural network training and inference code easier" ] }, { "cell_type": "markdown", "metadata": { "id": "6c7bFQ20LbLB" }, "source": [ "At its core, PyTorch is a library for\n", "- doing math on arrays\n", "- with automatic calculation of gradients\n", "- that is easy to accelerate with GPUs and distribute over nodes.\n", "\n", "Much of the time,\n", "we work at a remove from the core features of PyTorch,\n", "using abstractions from `torch.nn`\n", "or from frameworks on top of PyTorch.\n", "\n", "This tutorial builds those abstractions up\n", "from core PyTorch,\n", "showing how to go from basic iterated\n", "gradient computation and application\n", "to a solid training and validation loop.\n", "It is adapted from the PyTorch tutorial\n", "[What is `torch.nn` really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html).\n", "\n", "We assume familiarity with the fundamentals of ML and DNNs here,\n", "like gradient-based optimization and statistical learning.\n", "For refreshing on those, we recommend\n", "[3Blue1Brown's videos](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown)\n", "or\n", "[the NYU course on deep learning by Le Cun and Canziani](https://cds.nyu.edu/deep-learning/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full 
environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 1\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "6wJ8r7BTPB-t" }, "source": [ "# Getting data and making `Tensor`s" ] }, { "cell_type": "markdown", "metadata": { "id": "MpRyqPPYie-F" }, "source": [ "Before we can build a model,\n", "we need data.\n", "\n", "The code below uses the `requests` library to download the\n", "[MNIST dataset of handwritten digits](https://en.wikipedia.org/wiki/MNIST_database)\n", "from the internet.\n", "\n", "The data used to train state-of-the-art models these days\n", "is generally too large to be stored on the disk of any single machine\n", "(to say nothing of the RAM!),\n", "so fetching data over a network is a common first step in model training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CsokTZTMJ3x6" }, "outputs": [], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "def download_mnist(path):\n", " url = \"https://github.com/pytorch/tutorials/raw/main/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", "\n", " if not (path / filename).exists():\n", " content = requests.get(url + filename).content\n", " (path / filename).open(\"wb\").write(content)\n", "\n", " return path / filename\n", "\n", "\n", "data_path = Path(\"data\") if Path(\"data\").exists() else Path(\"../data\")\n", "path = data_path / \"downloaded\" / \"vector-mnist\"\n", "path.mkdir(parents=True, exist_ok=True)\n", "\n", "datafile = download_mnist(path)" ] }, { "cell_type": "markdown", "metadata": { "id": "-S0es1DujOyr" }, "source": [ "Larger data consumes more resources --\n", "when reading, writing, and sending over the network --\n", "so the dataset is compressed\n", "(`.gz` extension).\n", "\n", "Each piece of the dataset\n", "(training and validation inputs and outputs)\n", "is a single Python object\n", "(specifically, an array).\n", "We can persist Python objects to disk\n", "(also known as \"serialization\")\n", "and load them back in\n", "(also known as \"deserialization\")\n", "using the `pickle` library\n", "(`.pkl` extension)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QZosCF1xJ3x7" }, "outputs": [], "source": [ "import gzip\n", "import pickle\n", "\n", "\n", "def read_mnist(path):\n", " with gzip.open(path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", " return x_train, y_train, x_valid, y_valid\n", "\n", "x_train, y_train, x_valid, y_valid = read_mnist(datafile)" ] }, { "cell_type": "markdown", "metadata": { "id": "KIYUbKgmknDf" }, "source": [ "PyTorch provides its own array type,\n", "the `torch.Tensor`.\n", "The cell below converts our arrays into `torch.Tensor`s.\n", "\n", "Very roughly speaking, a \"tensor\" in ML\n", "just means the same thing as an\n", "\"array\" elsewhere in computer science.\n", "Terminology is different in\n", "[physics](https://physics.stackexchange.com/a/270445),\n", "[mathematics](https://en.wikipedia.org/wiki/Tensor#Using_tensor_products),\n", "and [computing](https://www.kdnuggets.com/2018/05/wtf-tensor.html),\n", "but here the term \"tensor\" is intended to connote\n", "an array that might have more than two dimensions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ea5d3Ggfkhea" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "D0AMKLxGkmc_" }, "source": [ "Tensors are defined by their contents:\n", "they are big rectangular blocks of numbers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yPvh8c_pkl5A" }, "outputs": [], "source": [ "print(x_train, y_train, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4UOYvwjFqdzu" }, "source": [ "Accessing the contents of `Tensor`s is called \"indexing\",\n", "and uses the same syntax as general Python indexing.\n", "It always returns a new `Tensor`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9zGDAPXVqdCm" }, "outputs": [], "source": [ "y_train[0], x_train[0, ::2]" ] }, { "cell_type": "markdown", "metadata": { "id": "QhJcOr8TmgmQ" }, "source": [ "PyTorch, like many libraries for high-performance array math,\n", "allows us to quickly and easily access metadata about our tensors." ] }, { "cell_type": "markdown", "metadata": { "id": "4ENirftAnIVM" }, "source": [ "The most important pieces of metadata about a `Tensor`,\n", "or any array, are its _dimension_\n", "and its _shape_.\n", "\n", "The dimension specifies how many indices you need to get a number\n", "out of an array." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhaN6qW0nA5t" }, "outputs": [], "source": [ "x_train.ndim, y_train.ndim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pYEk13yoGgz" }, "outputs": [], "source": [ "x_train[0, 0], y_train[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "rv2WWNcHkEeS" }, "source": [ "For a one-dimensional `Tensor` like `y_train`, the shape tells you how many entries it has.\n", "For a two-dimensional `Tensor` like `x_train`, the shape tells you how many rows and columns it has." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yZ6j-IGPJ3x7" }, "outputs": [], "source": [ "n, c = x_train.shape\n", "print(x_train.shape)\n", "print(y_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "H-HFN9WJo6FK" }, "source": [ "This metadata serves a similar purpose for `Tensor`s\n", "as type metadata serves for other objects in Python\n", "(and other programming languages).\n", "\n", "That is, types tell us whether an object is an acceptable\n", "input for or output of a function.\n", "Many functions on `Tensor`s, like indexing and\n", "matrix multiplication,\n", "can only accept as input `Tensor`s of a certain shape and dimension\n", "and will return as output `Tensor`s of a certain shape and dimension.\n", "\n", "So printing `ndim` and `shape` to track\n", "what's happening to `Tensor`s during a computation\n", "is an important piece of the debugging toolkit!" ] }, { "cell_type": "markdown", "metadata": { "id": "wCjuWKKNrWGM" }, "source": [ "We won't spend much time here on writing raw array math code in PyTorch,\n", "nor will we spend much time on how PyTorch works.\n", "\n", "> If you'd like to get better at writing PyTorch code,\n", "try out\n", "[these \"Tensor Puzzles\" by Sasha Rush](https://github.com/srush/Tensor-Puzzles).\n", "We wrote a bit about what these puzzles reveal about programming\n", "with arrays [here](https://twitter.com/charles_irl/status/1517991568266776577?s=20&t=i9cZJer0RPI2lzPIiCF_kQ).\n", "\n", "> If you'd like to get a better understanding of the internals\n", "of PyTorch, check out\n", "[this blog post by Edward Yang](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "As we'll see below,\n", "`torch.nn` provides most of what we need\n", "for building deep learning models." 
] }, { "cell_type": "markdown", "metadata": { "id": "Li5e_jiJpLSI" }, "source": [ "The `Tensor`s inside of the `x_train` `Tensor`\n", "aren't just any old blocks of numbers:\n", "they're images of handwritten digits.\n", "The `y_train` `Tensor` contains the identities of those digits.\n", "\n", "Let's take a look at a random example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4VsHk6xNJ3x8" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "import random\n", "\n", "import wandb # just for some convenience methods that convert tensors to human-friendly datatypes\n", "\n", "import text_recognizer.metadata.mnist as metadata # metadata module holds metadata separate from data\n", "\n", "idx = random.randint(0, len(x_train) - 1) # randint includes both endpoints\n", "example = x_train[idx]\n", "\n", "print(y_train[idx]) # the label of the image\n", "wandb.Image(example.reshape(*metadata.DIMS)).image # the image itself" ] }, { "cell_type": "markdown", "metadata": { "id": "PC3pwoJ9s-ts" }, "source": [ "We want to build a deep network that can take in an image\n", "and return the number that's in the image.\n", "\n", "We'll build that network\n", "by fitting it to `x_train` and `y_train`.\n", "\n", "We'll first do our fitting with just basic `torch` components and Python,\n", "then we'll add in other `torch` gadgets and goodies\n", "until we have a more realistic neural network fitting loop.\n", "\n", "Later in the labs,\n", "we'll see how to even more quickly build\n", "performant, robust fitting loops\n", "that have even more features\n", "by using libraries built on top of PyTorch." 
] }, { "cell_type": "markdown", "metadata": { "id": "DTLdqCIGJ3x6" }, "source": [ "# Building a DNN using only `torch.Tensor` methods and Python" ] }, { "cell_type": "markdown", "metadata": { "id": "8D8Xuh2xui3o" }, "source": [ "One of the really great features of PyTorch\n", "is that writing code in PyTorch feels\n", "very similar to writing other code in Python --\n", "unlike other deep learning frameworks\n", "that can sometimes feel like their own language\n", "or programming paradigm.\n", "\n", "This fact can sometimes be obscured\n", "when you're using lots of library code,\n", "so we start off by just using `Tensor`s and the Python standard library." ] }, { "cell_type": "markdown", "metadata": { "id": "tOV0bxySJ3x9" }, "source": [ "## Defining the model" ] }, { "cell_type": "markdown", "metadata": { "id": "ZLH_zUWkw3W0" }, "source": [ "We'll make the simplest possible neural network:\n", "a single layer that performs matrix multiplication,\n", "and adds a vector of biases.\n", "\n", "We'll need values for the entries of the matrix,\n", "which we generate randomly.\n", "\n", "We also need to tell PyTorch that we'll\n", "be taking gradients with respect to\n", "these `Tensor`s later, so we use `requires_grad`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1c21c8XQJ3x-" }, "outputs": [], "source": [ "import math\n", "\n", "import torch\n", "\n", "\n", "weights = torch.randn(784, 10) / math.sqrt(784)\n", "weights.requires_grad_()\n", "bias = torch.zeros(10, requires_grad=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "GZC8A01sytm2" }, "source": [ "We can combine our beloved Python operators,\n", "like `+` and `*` and `@` and indexing,\n", "to define the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8Eoymwooyq0-" }, "outputs": [], "source": [ "def linear(x: torch.Tensor) -> torch.Tensor:\n", " return x @ weights + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "5tIRHR_HxeZf" }, "source": [ "We need to normalize our model's outputs with a `softmax`\n", "to get our model to output something we can use\n", "as a probability distribution --\n", "the probability that the network assigns to each label for the image.\n", "\n", "For that, we'll need some `torch` math functions,\n", "like `torch.sum` and `torch.exp`.\n", "\n", "We compute the logarithm of that softmax value\n", "in part for numerical stability reasons\n", "and in part because\n", "[it is more natural to work with the logarithms of probabilities](https://youtu.be/LBemXHm_Ops?t=1071)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WuZRGSr4J3x-" }, "outputs": [], "source": [ "def log_softmax(x: torch.Tensor) -> torch.Tensor:\n", " return x - torch.log(torch.sum(torch.exp(x), axis=1))[:, None]\n", "\n", "def model(xb: torch.Tensor) -> torch.Tensor:\n", " return log_softmax(linear(xb))" ] }, { "cell_type": "markdown", "metadata": { "id": "-pBI4pOM011q" }, "source": [ "Typically, we split our dataset up into smaller \"batches\" of data\n", "and apply our model to one batch at a time.\n", "\n", "Since our dataset is just a `Tensor`,\n", "we can pull that off just with indexing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pXsHak23J3x_" }, "outputs": [], "source": [ "bs = 64 # batch size\n", "\n", "xb = x_train[0:bs] # a batch of inputs\n", "outs = model(xb) # outputs on that batch\n", "\n", "print(outs[0], outs.shape) # outputs on the first element of the batch" ] }, { "cell_type": "markdown", "metadata": { "id": "VPrG9x1DJ3x_" }, "source": [ "## Defining the loss and metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "zEwPJmgZ1HIp" }, "source": [ "Our model 
produces outputs, but they are mostly wrong,\n", "since we set the weights randomly.\n", "\n", "How can we quantify just how wrong our model is,\n", "so that we can make it better?" ] }, { "cell_type": "markdown", "metadata": { "id": "JY-2QZEu1Xc7" }, "source": [ "We want to compare the outputs and the target labels,\n", "but the model outputs a probability distribution,\n", "and the labels are just numbers.\n", "\n", "We can take the label that had the highest probability\n", "(the index of the largest output for each input,\n", "aka the `argmax` over `dim`ension `1`)\n", "and treat that as the model's prediction\n", "for the digit in the image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_sHmDw_cJ3yC" }, "outputs": [], "source": [ "def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:\n", " preds = torch.argmax(out, dim=1)\n", " return (preds == yb).float().mean()" ] }, { "cell_type": "markdown", "metadata": { "id": "PfrDJb2EF_uz" }, "source": [ "If we run that function on our model's `out`put`s`,\n", "we can confirm that the random model isn't doing well --\n", "we expect to see that something around one in ten predictions are correct." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8l3aRMNaJ3yD" }, "outputs": [], "source": [ "yb = y_train[0:bs]\n", "\n", "acc = accuracy(outs, yb)\n", "\n", "print(acc)" ] }, { "cell_type": "markdown", "metadata": { "id": "fxRfO1HQ3VYs" }, "source": [ "We can calculate how good our network is doing,\n", "so are we ready to use optimization to make it do better?\n", "\n", "Not yet!\n", "To train neural networks, we use gradients\n", "(aka derivatives).\n", "So all of the functions we use need to be differentiable --\n", "in particular they need to change smoothly so that a small change in input\n", "can only cause a small change in output.\n", "\n", "Our `argmax` breaks that rule\n", "(if the values at index `0` and index `N` are really close together,\n", "a tiny change can change the output by `N`)\n", "so we can't use it.\n", "\n", "If we try to run our `backward`s pass to get a gradient,\n", "we get a `RuntimeError`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g5AnK4md4kxv" }, "outputs": [], "source": [ "try:\n", " acc.backward()\n", "except RuntimeError as e:\n", " print(e)" ] }, { "cell_type": "markdown", "metadata": { "id": "HJ4WWHHJ460I" }, "source": [ "So we'll need something else:\n", "a differentiable function that gets smaller when\n", "our model gets better, aka a `loss`.\n", "\n", "The typical choice is to maximize the\n", "probability the network assigns to the correct label.\n", "\n", "We could try doing that directly,\n", "but more generally,\n", "we want the model's output probability distribution\n", "to match what we provide it -- \n", "here, we claim we're 100% certain in every label,\n", "but in general we allow for uncertainty.\n", "We quantify that match with the\n", "[cross entropy](https://charlesfrye.github.io/stats/2017/11/09/the-surprise-game.html).\n", "\n", "Cross entropies\n", "[give rise to most loss functions](https://youtu.be/LBemXHm_Ops?t=1316),\n", "including more familiar 
functions like the\n", "mean squared error and the mean absolute error.\n", "\n", "We can calculate it directly from the outputs and target labels\n", "using some cute tricks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-k20rW_rJ3yA" }, "outputs": [], "source": [ "def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n", " return -output[range(target.shape[0]), target].mean()\n", "\n", "loss_func = cross_entropy" ] }, { "cell_type": "markdown", "metadata": { "id": "YZa1DSGN7zPK" }, "source": [ "With random guessing on a dataset with 10 equally likely options,\n", "we expect our loss value to be close to the negative logarithm of 1/10:\n", "the amount of entropy in a uniformly random digit." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1bKRJ90MJ3yB" }, "outputs": [], "source": [ "print(loss_func(outs, yb), -torch.log(torch.tensor(1 / 10)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hTgFTdVgAGJW" }, "source": [ "Now we can call `.backward` without PyTorch complaining:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LH_ZpY0_e_6" }, "outputs": [], "source": [ "loss = loss_func(outs, yb)\n", "\n", "loss.backward()" ] }, { "cell_type": "markdown", "metadata": { "id": "ji0FA3dDACUk" }, "source": [ "But wait, where are the gradients?\n", "They weren't returned by `loss` above,\n", "so where could they be?\n", "\n", "They've been stored in the `.grad` attribute\n", "of the parameters of our model,\n", "`weights` and `bias`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zgtyyhp__s8a" }, "outputs": [], "source": [ "bias.grad" ] }, { "cell_type": "markdown", "metadata": { "id": "dWTYno0JJ3yD" }, "source": [ "## Defining and running the fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "TTR2Qo9F8ZLQ" }, "source": [ "We now have all the ingredients we need to fit a neural network to data:\n", "- data (`x_train`, 
`y_train`)\n", "- a network architecture with parameters (`model`, `weights`, and `bias`)\n", "- a `loss_func`tion to optimize (`cross_entropy`) that supports `.backward` computation of gradients\n", "\n", "We can put them together into a training loop\n", "just using normal Python features,\n", "like `for` loops, indexing, and function calls:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzNZVEiVJ3yE" }, "outputs": [], "source": [ "lr = 0.5  # learning rate hyperparameter\n", "epochs = 2  # how many epochs to train for\n", "\n", "for epoch in range(epochs):  # loop over the data repeatedly\n", "    for ii in range((n - 1) // bs + 1):  # in batches of size bs, so roughly n / bs of them\n", "        start_idx = ii * bs  # we are ii batches in, each of size bs\n", "        end_idx = start_idx + bs  # and we want the next bs entries\n", "\n", "        # pull batches from x and from y\n", "        xb = x_train[start_idx:end_idx]\n", "        yb = y_train[start_idx:end_idx]\n", "\n", "        # run model\n", "        pred = model(xb)\n", "\n", "        # get loss\n", "        loss = loss_func(pred, yb)\n", "\n", "        # calculate the gradients with a backwards pass\n", "        loss.backward()\n", "\n", "        # update the parameters\n", "        with torch.no_grad():  # we don't want to track gradients through this part!\n", "            # SGD learning rule: update with negative gradient scaled by lr\n", "            weights -= weights.grad * lr\n", "            bias -= bias.grad * lr\n", "\n", "            # ACHTUNG: PyTorch doesn't assume you're done with gradients\n", "            # until you say so -- by explicitly \"deleting\" them,\n", "            # i.e.
setting the gradients to 0.\n", "            weights.grad.zero_()\n", "            bias.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": { "id": "9J-BfH1e_Jkx" }, "source": [ "To check whether things are working,\n", "we confirm that the value of the `loss` has gone down\n", "and the `accuracy` has gone up:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mHgGCLaVJ3yE" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "E1ymEPYdcRHO" }, "source": [ "We can also run the model on a few examples\n", "to get a sense for how it's doing --\n", "always good for detecting bugs in our evaluation metrics!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O88PWejlcSTL" }, "outputs": [], "source": [ "# re-execute this cell for more samples\n", "idx = random.randint(0, len(x_train) - 1)  # randint includes both endpoints\n", "example = x_train[idx:idx+1]\n", "\n", "out = model(example)\n", "\n", "print(out.argmax())\n", "wandb.Image(example.reshape(28, 28)).image" ] }, { "cell_type": "markdown", "metadata": { "id": "7L1Gq1N_J3yE" }, "source": [ "# Refactoring with core `torch.nn` components" ] }, { "cell_type": "markdown", "metadata": { "id": "EE5nUXMG_Yry" }, "source": [ "This works!\n", "But it's rather tedious and manual --\n", "we have to track what the parameters of our model are,\n", "apply the parameter updates to each one individually ourselves,\n", "iterate over the dataset directly, etc.\n", "\n", "It's also very literal:\n", "many assumptions about our problem are hard-coded in the loop.\n", "If our dataset was, say, stored in CSV files\n", "and too large to fit in RAM,\n", "we'd have to rewrite most of our training code.\n", "\n", "For the next few sections,\n", "we'll progressively refactor this code to\n", "make it shorter, cleaner,\n", "and more extensible\n", "using tools from the sublibraries of PyTorch:\n", "`torch.nn`, `torch.optim`, and `torch.utils.data`."
] }, { "cell_type": "markdown", "metadata": { "id": "BHEixRsbJ3yF" }, "source": [ "## Using `torch.nn.functional` for stateless computation" ] }, { "cell_type": "markdown", "metadata": { "id": "9k94IlN58lWa" }, "source": [ "First, let's drop that `cross_entropy` and `log_softmax`\n", "we implemented ourselves --\n", "whenever you find yourself implementing basic mathematical operations\n", "in PyTorch code you want to put in production,\n", "take a second to check whether the code you need's not out\n", "there in a library somewhere.\n", "You'll get fewer bugs and faster code for less effort!" ] }, { "cell_type": "markdown", "metadata": { "id": "sP-giy1a9Ct4" }, "source": [ "Both of those functions operated on their inputs\n", "without reference to any global variables,\n", "so we find their implementation in `torch.nn.functional`,\n", "where stateless computations live." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vfWyJW1sJ3yF" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "loss_func = F.cross_entropy\n", "\n", "def model(xb):\n", " return xb @ weights + bias" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqYIkcvpJ3yF" }, "outputs": [], "source": [ "print(loss_func(model(xb), yb), accuracy(model(xb), yb)) # should be unchanged from above!" 
] }, { "cell_type": "markdown", "metadata": { "id": "vXFyM1tKJ3yF" }, "source": [ "## Using `torch.nn.Module` to define functions whose state is given by `torch.nn.Parameter`s" ] }, { "cell_type": "markdown", "metadata": { "id": "PInL-9sbCKnv" }, "source": [ "Perhaps the biggest issue with our setup is how we're handling state.\n", "\n", "The `model` function refers to two global variables: `weights` and `bias`.\n", "These variables are critical for it to run,\n", "but they are defined outside of the function\n", "and are manipulated willy-nilly by other operations.\n", "\n", "This problem arises because of a fundamental tension in\n", "deep neural networks.\n", "We want to use them _as functions_ --\n", "when the time comes to make predictions in production,\n", "we put inputs in and get outputs out,\n", "just like any other function.\n", "But neural networks are fundamentally stateful,\n", "because they are _parameterized_ functions,\n", "and fiddling with the values of those parameters\n", "is the purpose of optimization.\n", "\n", "PyTorch's solution to this is the `nn.Module` class:\n", "a Python class that is callable like a function\n", "but tracks state like an object.\n", "\n", "Whatever `Tensor`s representing state we want PyTorch\n", "to track for us inside of our model\n", "get defined as `nn.Parameter`s and attached to the model\n", "as attributes." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A34hxhd0J3yF" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "\n", "class MNISTLogistic(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()  # the nn.Module.__init__ method does important setup, so this is mandatory\n", "        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))\n", "        self.bias = nn.Parameter(torch.zeros(10))" ] }, { "cell_type": "markdown", "metadata": { "id": "pFD_sIRaFbbx" }, "source": [ "We define the computation that uses that state\n", "in the `.forward` method.\n", "\n", "Using some behind-the-scenes magic,\n", "this method gets called if we treat\n", "the instantiated `nn.Module` like a function by\n", "passing it arguments.\n", "You can give similar special powers to your own classes\n", "by defining the `__call__` \"magic dunder\" method\n", "on them.\n", "\n", "> We've separated the definition of the `.forward` method\n", "from the definition of the class above and\n", "attached the method to the class manually below.\n", "We only do this to make the construction of the class\n", "easier to read and understand in the context of this notebook --\n", "a neat little trick we'll use a lot in these labs.\n", "Normally, we'd just define the `nn.Module` all at once."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QAKK3dlFT9w" }, "outputs": [], "source": [ "def forward(self, xb: torch.Tensor) -> torch.Tensor:\n", " return xb @ self.weights + self.bias\n", "\n", "MNISTLogistic.forward = forward\n", "\n", "model = MNISTLogistic() # instantiated as an object\n", "print(model(xb)[:4]) # callable like a function\n", "loss = loss_func(model(xb), yb) # composable like a function\n", "loss.backward() # we can still take gradients through it\n", "print(model.weights.grad[::17,::2]) # and they show up in the .grad attribute" ] }, { "cell_type": "markdown", "metadata": { "id": "r-Yy2eYTHMVl" }, "source": [ "But how do we apply our updates?\n", "Do we need to access `model.weights.grad` and `model.weights`,\n", "like we did in our first implementation?\n", "\n", "Luckily, we don't!\n", "We can iterate over all of our model's `torch.nn.Parameters`\n", "via the `.parameters` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vM59vE-5JiXV" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tbFCdWBkNft0" }, "source": [ "That means we no longer need to assume we know the names\n", "of the model's parameters when we do our update --\n", "we can reuse the same loop with different models." 
] }, { "cell_type": "markdown", "metadata": { "id": "hA925fIUK0gg" }, "source": [ "Let's wrap all of that up into a single function to `fit` our model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q9NxJZTOJ3yG" }, "outputs": [], "source": [ "def fit():\n", " for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in model.parameters(): # finds params automatically\n", " p -= p.grad * lr\n", " model.zero_grad()\n", "\n", "fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjmsb94mK8po" }, "source": [ "and check that we didn't break anything,\n", "i.e. that our model still gets accuracy much higher than 10%:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vo65cLS5J3yH" }, "outputs": [], "source": [ "print(accuracy(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "fxYq2sCLJ3yI" }, "source": [ "# Refactoring intermediate `torch.nn` components: network layers, optimizers, and data handling" ] }, { "cell_type": "markdown", "metadata": { "id": "95c67wZCMynl" }, "source": [ "Our model's state is being handled respectably,\n", "our fitting loop is 2x shorter,\n", "and we can train different models if we'd like.\n", "\n", "But we're not done yet!\n", "Many steps we're doing manually above\n", "are already built in to `torch`." 
] }, { "cell_type": "markdown", "metadata": { "id": "CE2VFjDZJ3yI" }, "source": [ "## Using `torch.nn.Linear` for the model definition" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvcnrz2uJ3yI" }, "source": [ "As with our hand-rolled `cross_entropy`\n", "that could be profitably replaced with\n", "the industrial grade `nn.functional.cross_entropy`,\n", "we should replace our bespoke linear layer\n", "with something made by experts.\n", "\n", "Instead of defining `nn.Parameters`,\n", "effectively raw `Tensor`s, as attributes\n", "of our `nn.Module`,\n", "we can define other `nn.Module`s as attributes.\n", "PyTorch assigns the `nn.Parameters`\n", "of any child `nn.Module`s to the parent, recursively.\n", "\n", "These `nn.Module`s are reusable --\n", "say, if we want to make a network with multiple layers of the same type --\n", "and there are lots of them already defined:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l-EKdhXcPjq2" }, "outputs": [], "source": [ "import textwrap\n", "\n", "print(\"torch.nn.Modules:\", *textwrap.wrap(\", \".join(torch.nn.modules.__all__)), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "KbIIQMaBQC45" }, "source": [ "We want the humble `nn.Linear`,\n", "which applies the same\n", "matrix multiplication and bias operation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JHwS-1-rJ3yJ" }, "outputs": [], "source": [ "class MNISTLogistic(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.lin = nn.Linear(784, 10) # pytorch finds the nn.Parameters inside this nn.Module\n", "\n", " def forward(self, xb):\n", " return self.lin(xb) # call nn.Linear.forward here" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Mcb0UvcmJ3yJ" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "print(loss_func(model(xb), yb)) # loss is still close to 2.3" ] }, { "cell_type": "markdown", "metadata": { "id": "5hcjV8A2QjQJ" }, "source": [ "We can see that the `nn.Linear` module is a \"child\"\n", "of the `model`,\n", "and we don't see the matrix of weights and the bias vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yKkU-GIPOQq4" }, "outputs": [], "source": [ "print(*list(model.children()))" ] }, { "cell_type": "markdown", "metadata": { "id": "kUdhpItWQui_" }, "source": [ "but if we ask for the model's `.parameters`,\n", "we find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G1yGOj2LNDsS" }, "outputs": [], "source": [ "print(*list(model.parameters()), sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DFlQyKl6J3yJ" }, "source": [ "## Applying gradients with `torch.optim.Optimizer`" ] }, { "cell_type": "markdown", "metadata": { "id": "IqImMaenJ3yJ" }, "source": [ "Applying gradients to optimize parameters\n", "and resetting those gradients to zero\n", "are very common operations.\n", "\n", "So why are we doing that by hand?\n", "Now that our model is a `torch.nn.Module` using `torch.nn.Parameters`,\n", "we don't have to --\n", "we just need to point a `torch.optim.Optimizer`\n", "at the parameters of our model.\n", "\n", "While we're at it, we can also use a more sophisticated optimizer --\n", "`Adam` is a common first choice." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f5AUNLEKJ3yJ" }, "outputs": [], "source": [ "from torch import optim\n", "\n", "\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jK9dy0sNJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "print(\"before training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(\"after training:\", loss_func(model(xb), yb), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "4yk9re3HJ3yK" }, "source": [ "## Organizing data with `torch.utils.data.Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "0ap3fcZpTIqJ" }, "source": [ "We're also manually handling the data.\n", "First, we're independently and manually aligning\n", "the inputs, `x_train`, and the outputs, `y_train`.\n", "\n", "Aligned data is important in ML.\n", "We want a way to combine multiple data sources together\n", "and index into them simultaneously.\n", "\n", "That's done with `torch.utils.data.Dataset`.\n", "Just inherit from it and implement two methods to support indexing:\n", "`__getitem__` and `__len__`." ] }, { "cell_type": "markdown", "metadata": { "id": "HPj25nkoVWRi" }, "source": [ "We'll cheat a bit here and pull in the `BaseDataset`\n", "class from the `text_recognizer` library,\n", "so that we can start getting some exposure\n", "to the codebase for the labs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NpltQ-4JJ3yK" }, "outputs": [], "source": [ "from text_recognizer.data.util import BaseDataset\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV1bc4R5Vz0N" }, "source": [ "The cell below will pull up the documentation for this class,\n", "which effectively just indexes into the two `Tensor`s simultaneously.\n", "\n", "It can also apply transformations to the inputs and targets.\n", "We'll see that later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XUWJ8yIWU28G" }, "outputs": [], "source": [ "BaseDataset??" ] }, { "cell_type": "markdown", "metadata": { "id": "zMQDHJNzWMtf" }, "source": [ "This makes our code a tiny bit cleaner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6iyqG4kEJ3yK" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "opt = configure_optimizer(model)\n", "\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " xb, yb = train_ds[ii * bs: ii * bs + bs] # xb and yb in one line!\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "pTtRPp_iJ3yL" }, "source": [ "## Batching up data with `torch.utils.data.DataLoader`" ] }, { "cell_type": "markdown", "metadata": { "id": "FPnaMyokWSWv" }, "source": [ "We're also still manually building our batches.\n", "\n", "Making batches out of datasets is a core component of contemporary deep learning training workflows,\n", "so unsurprisingly PyTorch offers a tool for it: the `DataLoader`.\n", "\n", "We just need to hand our `Dataset` to the `DataLoader`\n", "and choose a `batch_size`.\n", "\n", "We can tune that parameter and other `DataLoader` arguments,\n", "like `num_workers` and `pin_memory`,\n", "to improve the 
performance of our training loop.\n", "For more on the impact of `DataLoader` parameters on the behavior of PyTorch code, see\n", "[this blog post and Colab](https://wandb.ai/wandb/trace/reports/A-Public-Dissection-of-a-PyTorch-Training-Step--Vmlldzo5MDE3NjU)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aqXX7JGCJ3yL" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iWry2CakJ3yL" }, "outputs": [], "source": [ "def fit(self: nn.Module, train_dataloader: DataLoader):\n", " opt = configure_optimizer(self)\n", "\n", " for epoch in range(epochs):\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", "MNISTLogistic.fit = fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9pfdSJBIXT8o" }, "outputs": [], "source": [ "model = MNISTLogistic()\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "RAs8-3IfJ3yL" }, "source": [ "Compare the ten line `fit` function with our first training loop (reproduced below) --\n", "much cleaner _and_ much more powerful!" 
] }, { "cell_type": "markdown", "metadata": { "id": "_a51dZrLJ3yL" }, "source": [ "```python\n", "lr = 0.5 # learning rate\n", "epochs = 2 # how many epochs to train for\n", "\n", "for epoch in range(epochs):\n", " for ii in range((n - 1) // bs + 1):\n", " start_idx = ii * bs\n", " end_idx = start_idx + bs\n", " xb = x_train[start_idx:end_idx]\n", " yb = y_train[start_idx:end_idx]\n", " pred = model(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " with torch.no_grad():\n", " weights -= weights.grad * lr\n", " bias -= bias.grad * lr\n", " weights.grad.zero_()\n", " bias.grad.zero_()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jiQe3SEWyZo4" }, "source": [ "## Swapping in another model" ] }, { "cell_type": "markdown", "metadata": { "id": "KykHpZEWyZo4" }, "source": [ "To see that our new `.fit` is more powerful,\n", "let's use it with a different model.\n", "\n", "Specifically, let's draw in the `MLP`,\n", "or \"multi-layer perceptron\" model\n", "from the `text_recognizer` library\n", "in our codebase." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1FtGJg1CyZo4" }, "outputs": [], "source": [ "from text_recognizer.models.mlp import MLP\n", "\n", "\n", "MLP.fit = fit # attach our fitting loop" ] }, { "cell_type": "markdown", "metadata": { "id": "kJiP3a-8yZo4" }, "source": [ "If you look in the `.forward` method of the `MLP`,\n", "you'll see that it uses\n", "some modules and functions we haven't seen, like\n", "[`nn.Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html),\n", "but otherwise fits the interface of our training loop:\n", "the `MLP` is callable and it takes an `x` and returns a guess for the `y` labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hj-0UdJwyZo4" }, "outputs": [], "source": [ "MLP.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FS7dxQ4VyZo4" }, "source": [ "If we look at the constructor, `__init__`,\n", "we see that the `nn.Module`s (`fc` and `dropout`)\n", "are initialized and attached as attributes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x0NpkeA8yZo5" }, "outputs": [], "source": [ "MLP.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "Uygy5HsUyZo5" }, "source": [ "We also see that we are required to provide a `data_config`\n", "dictionary and can optionally configure the module with `args`.\n", "\n", "For now, we'll only do the bare minimum and specify\n", "the contents of the `data_config`:\n", "the `input_dims` for `x` and the `mapping`\n", "from class index in `y` to class label,\n", "which we can see are used in the `__init__` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y6BEl_I-yZo5" }, "outputs": [], "source": [ "digits_to_9 = list(range(10))\n", "data_config = {\"input_dims\": (784,), \"mapping\": {digit: str(digit) for digit in digits_to_9}}\n", "data_config" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bEuNc38JyZo5" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model" ] }, { "cell_type": "markdown", "metadata": { "id": "CWQK2DWWyZo6" }, "source": [ "The resulting `MLP` is a bit larger than our `MNISTLogistic` model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zs1s6ahUyZo8" }, "outputs": [], "source": [ "model.fc1.weight" ] }, { "cell_type": "markdown", "metadata": { "id": "JVLkK78FyZo8" }, "source": [ "But that doesn't matter for our fitting loop,\n", "which happily optimizes this model on batches from the `train_dataloader`,\n", "though it takes a bit longer." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-DItXLoyZo9" }, "outputs": [], "source": [ "%%time\n", "\n", "print(\"before training:\", loss_func(model(xb), yb))\n", "\n", "train_ds = BaseDataset(x_train, y_train)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)\n", "fit(model, train_dataloader)\n", "\n", "print(\"after training:\", loss_func(model(xb), yb))" ] }, { "cell_type": "markdown", "metadata": { "id": "9QgTv2yzJ3yM" }, "source": [ "# Extra goodies: data organization, validation, and acceleration" ] }, { "cell_type": "markdown", "metadata": { "id": "Vx-CcCesbmyw" }, "source": [ "Before we've got a DNN fitting loop that's welcome in polite company,\n", "we need three more features:\n", "organized data loading code, validation, and GPU acceleration." ] }, { "cell_type": "markdown", "metadata": { "id": "8LWja5aDJ3yN" }, "source": [ "## Making the GPU go brrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "7juxQ_Kp-Tx0" }, "source": [ "Everything we've done so far has been on\n", "the central processing unit of the computer, or CPU.\n", "When programming in Python,\n", "it is on the CPU that\n", "almost all of our code becomes concrete instructions\n", "that cause a machine to move electrons around."
] }, { "cell_type": "markdown", "metadata": { "id": "R25L3z8eAWIO" }, "source": [ "That's okay for small-to-medium neural networks,\n", "but computation quickly becomes a bottleneck that makes achieving\n", "good performance infeasible.\n", "\n", "In general, the problem of CPUs,\n", "which are general purpose computing devices,\n", "being too slow is solved by using more specialized accelerator chips --\n", "in the extreme case, application-specific integrated circuits (ASICs)\n", "that can only perform a single task,\n", "the hardware equivalents of\n", "[sword-billed hummingbirds](https://en.wikipedia.org/wiki/Sword-billed_hummingbird) or\n", "[Canada lynx](https://en.wikipedia.org/wiki/Canada_lynx).\n", "\n", "Luckily, really excellent chips\n", "for accelerating deep learning are readily available\n", "as a consumer product:\n", "graphics processing units (GPUs),\n", "which are designed to perform large matrix multiplications in parallel.\n", "Their name derives from their origins\n", "applying large matrix multiplications to manipulate shapes and textures\n", "in graphics engines for video games and CGI.\n", "\n", "If your system has a GPU and the right libraries installed\n", "for `torch` compatibility,\n", "the cell below will print information about its state."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xxy-Gt9wJ3yN" }, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " !nvidia-smi\n", "else:\n", " print(\"☹️\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x6qAX1OECiWk" }, "source": [ "PyTorch is designed to allow for computation to occur both on the CPU and the GPU --\n", "even simultaneously, which can be critical for high performance.\n", "\n", "So once we start using acceleration, we need to be more precise about where the\n", "data inside our `Tensor`s lives --\n", "on which physical `torch.device` it can be found.\n", "\n", "On compatible systems, the cell below will\n", "move all of the model's parameters `.to` the GPU\n", "(another good reason to use `torch.nn.Parameter`s and not handle them yourself!)\n", "and then move a batch of inputs and targets there as well\n", "before applying the model and calculating the loss.\n", "\n", "To confirm this worked, look for the name of the device in the output of the cell,\n", "alongside other information about the loss `Tensor`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jGkpfEmbJ3yN" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "model.to(device)\n", "\n", "loss_func(model(xb.to(device)), yb.to(device))" ] }, { "cell_type": "markdown", "metadata": { "id": "-zdPR06eDjIX" }, "source": [ "Rather than rewrite our entire `.fit` function,\n", "we'll make use of the features of the `text_recognizer.data.util.BaseDataset`.\n", "\n", "Specifically,\n", "we can provide a `transform` that is called on the inputs\n", "and a `target_transform` that is called on the labels\n", "before they are returned.\n", "In the FSDL codebase,\n", "this feature is used for data preparation, like\n", "reshaping, resizing,\n", "and normalization.\n", "\n", "We'll use this as an opportunity to put the `Tensor`s on the appropriate device." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m8WQS9Zo_Did" }, "outputs": [], "source": [ "def push_to_device(tensor):\n", "    return tensor.to(device)\n", "\n", "train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", "train_dataloader = DataLoader(train_ds, batch_size=bs)" ] }, { "cell_type": "markdown", "metadata": { "id": "nmg9HMSZFmqR" }, "source": [ "We don't need to change anything about our fitting code to run it on the GPU!\n", "\n", "Note: given the small size of this model and the data,\n", "the speedup here can sometimes be fairly moderate (like 2x).\n", "For larger models, GPU acceleration can easily lead to 50-100x faster iterations."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "v1TVc06NkXrU" }, "outputs": [], "source": [ "%%time\n", "\n", "model = MLP(data_config)\n", "model.to(device)\n", "\n", "model.fit(train_dataloader)\n", "\n", "print(loss_func(model(push_to_device(xb)), push_to_device(yb)))" ] }, { "cell_type": "markdown", "metadata": { "id": "L7thbdjKTjAD" }, "source": [ "Writing high performance GPU-accelerated neural network code is challenging.\n", "There are many sharp edges, so the default\n", "strategy is imitation (basing all work on existing verified quality code)\n", "and conservatism bordering on paranoia about change.\n", "For a casual introduction to some of the core principles, see\n", "[Horace He's blogpost](https://horace.io/brrr_intro.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "LnpbEVE5J3yM" }, "source": [ "## Adding validation data and organizing data code with a `DataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "EqYHjiG8b_4J" }, "source": [ "Just doing well on data you've seen before is not that impressive --\n", "the network could just memorize the label for each input digit.\n", "\n", "We need to check performance on a set of data points that weren't used\n", "directly to optimize the model,\n", "commonly called the validation set." 
] }, { "cell_type": "markdown", "metadata": { "id": "7e6z-Fh8dOnN" }, "source": [ "We already downloaded one up above,\n", "but that was all the way at the beginning of the notebook,\n", "and I've already forgotten about it.\n", "\n", "In general, it's easy for data-loading code,\n", "the redheaded stepchild of the ML codebase,\n", "to become messy and fall out of sync.\n", "\n", "A proper `DataModule` collects up all of the code required\n", "to prepare data on a machine,\n", "sets it up as a collection of `Dataset`s,\n", "and turns those `Dataset`s into `DataLoader`s,\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0WxgRa2GJ3yM" }, "outputs": [], "source": [ "class MNISTDataModule:\n", " url = \"https://github.com/pytorch/tutorials/raw/master/_static/\"\n", " filename = \"mnist.pkl.gz\"\n", " \n", " def __init__(self, dir, bs=32):\n", " self.dir = dir\n", " self.bs = bs\n", " self.path = self.dir / self.filename\n", "\n", " def prepare_data(self):\n", " if not (self.path).exists():\n", " content = requests.get(self.url + self.filename).content\n", " self.path.open(\"wb\").write(content)\n", "\n", " def setup(self):\n", " with gzip.open(self.path, \"rb\") as f:\n", " ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")\n", "\n", " x_train, y_train, x_valid, y_valid = map(\n", " torch.tensor, (x_train, y_train, x_valid, y_valid)\n", " )\n", " \n", " self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)\n", " self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)\n", "\n", " def train_dataloader(self):\n", " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)\n", " \n", " def val_dataloader(self):\n", " return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "x-8T_MlWifMe" }, 
"source": [ "We'll cover `DataModule`s in more detail later.\n", "\n", "We can now incorporate our `DataModule`\n", "into the fitting pipeline\n", "by calling its methods as needed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mcFcbRhSJ3yN" }, "outputs": [], "source": [ "def fit(self: nn.Module, datamodule):\n", " datamodule.prepare_data()\n", " datamodule.setup()\n", "\n", " val_dataloader = datamodule.val_dataloader()\n", " \n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(\"before start of training:\", valid_loss / len(val_dataloader))\n", "\n", " opt = configure_optimizer(self)\n", " train_dataloader = datamodule.train_dataloader()\n", " for epoch in range(epochs):\n", " self.train()\n", " for xb, yb in train_dataloader:\n", " pred = self(xb)\n", " loss = loss_func(pred, yb)\n", "\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", "\n", " self.eval()\n", " with torch.no_grad():\n", " valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)\n", "\n", " print(epoch, valid_loss / len(val_dataloader))\n", "\n", "\n", "MNISTLogistic.fit = fit\n", "MLP.fit = fit" ] }, { "cell_type": "markdown", "metadata": { "id": "-Uqey9w6jkv9" }, "source": [ "Now we've substantially cut down on the \"hidden state\" in our fitting code:\n", "if you've defined the `MNISTLogistic` and `MNISTDataModule` classes,\n", "then you can train a network with just the cell below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uxN1yV6DX6Nz" }, "outputs": [], "source": [ "model = MLP(data_config)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=32)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "2zHA12Iih0ML" }, "source": [ "You may have noticed a few other changes in the `.fit` method:\n", "\n", "- `self.eval` vs `self.train`:\n", "it's helpful to have features of neural networks that behave differently in `train`ing\n", "than they do in production or `eval`uation.\n", "[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)\n", "and\n", "[BatchNorm](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)\n", "are among the most popular examples.\n", "We need to take this into account now that we\n", "have a validation loop.\n", "- The return of `torch.no_grad`: in our first few implementations,\n", "we had to use `torch.no_grad` to avoid tracking gradients while we were updating parameters.\n", "Now, we need to use it to avoid tracking gradients during validation." 
] }, { "cell_type": "markdown", "metadata": { "id": "BaODkqTnJ3yO" }, "source": [ "This is starting to get a bit hairy again!\n", "We're back up to about 30 lines of code,\n", "right where we started\n", "(but now with way more features!).\n", "\n", "Much like `torch.nn` provides useful tools and interfaces for\n", "defining neural networks,\n", "iterating over batches,\n", "and calculating gradients,\n", "frameworks on top of PyTorch, like\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/),\n", "provide useful tools and interfaces\n", "for an even higher level of abstraction over neural network training.\n", "\n", "For serious deep learning codebases,\n", "you'll want to use a framework at that level of abstraction --\n", "either one of the popular open frameworks or one developed in-house.\n", "\n", "For most of these frameworks,\n", "you'll still need facility with core PyTorch:\n", "at least for defining models and\n", "often for defining data pipelines as well." ] }, { "cell_type": "markdown", "metadata": { "id": "-4piIilkyZpD" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "E482VfIlyZpD" }, "source": [ "### 🌟 Try out different hyperparameters for the `MLP` and for training." ] }, { "cell_type": "markdown", "metadata": { "id": "IQ8bkAxNyZpD" }, "source": [ "The `MLP` class is configured via the `args` argument to its constructor,\n", "which can set the values of hyperparameters like the width of layers and the degree of dropout:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Tl-AvMVyZpD" }, "outputs": [], "source": [ "MLP.__init__??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0HfbQ0KkyZpD" }, "source": [ "As the type signature indicates, `args` is an `argparse.Namespace`.\n", "[`argparse` is used to build command line interfaces in Python](https://realpython.com/command-line-interfaces-python-argparse/),\n", "and later on we'll see how to configure models\n", "and launch training jobs from the command line\n", "in the FSDL codebase.\n", "\n", "For now, we'll do it by hand, by passing a dictionary to `Namespace`.\n", "\n", "Edit the cell below to change the `args`, `epochs`, and `b`atch `s`ize.\n", "\n", "Can you get a final `valid`ation `acc`uracy of 98%?\n", "Can you get to 95% 2x faster than the baseline `MLP`?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-vVtGJhtyZpD" }, "outputs": [], "source": [ "%%time \n", "from argparse import Namespace # you'll need this\n", "\n", "args = None # edit this\n", "\n", "epochs = 2 # used in fit\n", "bs = 32 # used by the DataModule\n", "\n", "\n", "# used in fit, play around with this if you'd like\n", "def configure_optimizer(model: nn.Module) -> optim.Optimizer:\n", " return optim.Adam(model.parameters(), lr=3e-4)\n", "\n", "\n", "model = MLP(data_config, args=args)\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7yyxc3uxyZpD" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] }, { "cell_type": "markdown", "metadata": { "id": "0ZHygZtgyZpE" }, "source": [ "### 🌟🌟🌟 Write your own `nn.Module`." 
] }, { "cell_type": "markdown", "metadata": { "id": "r3Iu73j3yZpE" }, "source": [ "Designing new models is one of the most fun\n", "aspects of building an ML-powered application.\n", "\n", "Can you make an `nn.Module` that looks different from\n", "the standard `MLP` but still gets 98% validation accuracy or higher?\n", "You might start from the `MLP` and\n", "[add more layers to it](https://i.imgur.com/qtlP5LI.png)\n", "while adding more bells and whistles.\n", "Take care to keep the shapes of the `Tensor`s aligned as you go.\n", "\n", "Here are some tricks you can try that are especially helpful with deeper networks:\n", "- Add [`BatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html)\n", "layers, which can improve\n", "[training stability and loss conditioning](https://myrtle.ai/how-to-train-your-resnet-7-batch-norm/)\n", "- Add a linear \"skip connection\" layer that is applied to the inputs and whose outputs are added directly to the last layer's outputs\n", "- Use other [activation functions](https://pytorch.org/docs/stable/nn.functional.html#non-linear-activation-functions),\n", "like [selu](https://pytorch.org/docs/stable/generated/torch.nn.functional.selu.html)\n", "or [mish](https://pytorch.org/docs/stable/generated/torch.nn.functional.mish.html)\n", "\n", "If you want to make an `nn.Module` that can have different depths,\n", "check out the\n", "[`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) class."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JsF_RfrDyZpE" }, "outputs": [], "source": [ "class YourModel(nn.Module):\n", " def __init__(self): # add args and kwargs here as you like\n", " super().__init__()\n", " # use those args and kwargs to set up the submodules\n", " self.ps = nn.Parameter(torch.zeros(10))\n", "\n", " def forward(self, xb): # overwrite this to use your nn.Modules from above\n", " xb = torch.stack([self.ps for ii in range(len(xb))])\n", " return xb\n", " \n", " \n", "YourModel.fit = fit # don't forget this!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6OQidtGyZpE" }, "outputs": [], "source": [ "model = YourModel()\n", "model.to(device)\n", "\n", "datamodule = MNISTDataModule(dir=path, bs=bs)\n", "\n", "model.fit(datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CH0U4ODoyZpE" }, "outputs": [], "source": [ "val_dataloader = datamodule.val_dataloader()\n", "valid_acc = sum(accuracy(model(xb), yb) for xb, yb in val_dataloader) / len(val_dataloader)\n", "valid_acc" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab01_pytorch.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab08/notebooks/lab02a_lightning.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02a: 
PyTorch Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The core components of a PyTorch Lightning training loop: `LightningModule`s and `Trainer`s.\n", "- Useful quality-of-life improvements offered by PyTorch Lightning: `LightningDataModule`s, `Callback`s, and `Metric`s\n", "- How we use these features in the FSDL codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Lightning?" 
] }, { "cell_type": "markdown", "metadata": { "id": "bP8iJW_bg7IC" }, "source": [ "PyTorch is a powerful library for executing differentiable\n", "tensor operations with hardware acceleration\n", "and it includes many neural network primitives,\n", "but it has no concept of \"training\".\n", "At a high level, an `nn.Module` is a stateful function with gradients\n", "and a `torch.optim.Optimizer` can update that state using gradients,\n", "but there are no pre-built tools in PyTorch to iteratively generate those gradients from data." ] }, { "cell_type": "markdown", "metadata": { "id": "a7gIA-Efy91E" }, "source": [ "So the first thing many folks do in PyTorch is write that code --\n", "a \"training loop\" to iterate over their `DataLoader`,\n", "which in pseudocode might look something like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Y3ewkWrwzDA8" }, "source": [ "```python\n", "for batch in dataloader:\n", "    inputs, targets = batch\n", "\n", "    outputs = model(inputs)\n", "    loss = some_loss_function(outputs, targets)\n", "\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "\n", "    optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "OYUtiJWize82" }, "source": [ "This is a solid start, but other needs immediately arise.\n", "You'll want to run your model on validation and test data,\n", "which need their own `DataLoader`s.\n", "Once finished, you'll want to save your model --\n", "and for long-running jobs, you probably want\n", "to save checkpoints of the training process\n", "so that it can be resumed in case of a crash.\n", "For state-of-the-art model performance in many domains,\n", "you'll want to distribute your training across multiple nodes/machines\n", "and across multiple GPUs within those nodes."
] }, { "cell_type": "markdown", "metadata": { "id": "0untumvjy5fm" }, "source": [ "That's just the tip of the iceberg, and you want\n", "all those features to work for lots of models and datasets,\n", "not just the one you're writing now." ] }, { "cell_type": "markdown", "metadata": { "id": "TNPpi4OZjMbu" }, "source": [ "You don't want to write all of this yourself.\n", "\n", "So unless you are at a large organization that has a dedicated team\n", "for building that \"framework\" code,\n", "you'll want to use an existing library." ] }, { "cell_type": "markdown", "metadata": { "id": "tnQuyVqUjJy8" }, "source": [ "PyTorch Lightning is a popular framework on top of PyTorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ecipNFTgZDt" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "version = pl.__version__\n", "\n", "docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/\" # version can also be latest, stable\n", "docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "bE82xoEikWkh" }, "source": [ "At its core, PyTorch Lightning provides\n", "\n", "1. the `pl.Trainer` class, which organizes and executes your training, validation, and test loops, and\n", "2. the `pl.LightningModule` class, which links optimizers to models and defines how the model behaves during training, validation, and testing.\n", "\n", "Both of these are kitted out with all the features\n", "a cutting-edge deep learning codebase needs:\n", "- flags for switching device types and distributed computing strategy\n", "- saving, checkpointing, and resumption\n", "- calculation and logging of metrics\n", "\n", "and much more.\n", "\n", "Importantly these features can be easily\n", "added, removed, extended, or bypassed\n", "as desired, meaning your code isn't constrained by the framework." 
] }, { "cell_type": "markdown", "metadata": { "id": "uuJUDmCeT3RK" }, "source": [ "In some ways, you can think of Lightning as a tool for \"organizing\" your PyTorch code,\n", "as shown in the video below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wTt0TBs5TZpm" }, "outputs": [], "source": [ "import IPython.display as display\n", "\n", "\n", "display.IFrame(src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_mod_vid.m4v\",\n", " width=720, height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "CGwpDn5GWn_X" }, "source": [ "That's opposed to the other way frameworks are designed,\n", "to provide abstractions over the lower-level library\n", "(here, PyTorch).\n", "\n", "Because of this \"organize don't abstract\" style,\n", "writing PyTorch Lightning code involves\n", "a lot of over-riding of methods --\n", "you inherit from a class\n", "and then implement the specific version of a general method\n", "that you need for your code,\n", "rather than Lightning providing a bunch of already\n", "fully-defined classes that you just instantiate,\n", "using arguments for configuration." ] }, { "cell_type": "markdown", "metadata": { "id": "TXiUcQwan39S" }, "source": [ "# The `pl.LightningModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "_3FffD5Vn6we" }, "source": [ "The first of our two core classes,\n", "the `LightningModule`,\n", "is like a souped-up `torch.nn.Module` --\n", "it inherits all of the `Module` features,\n", "but adds more." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QWwSStJTP28" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "issubclass(pl.LightningModule, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "q1wiBVSTuHNT" }, "source": [ "To demonstrate how this class works,\n", "we'll build up a `LinearRegression` model dynamically,\n", "method by method.\n", "\n", "For this example we hard code lots of the details,\n", "but the real benefit comes when the details are configurable.\n", "\n", "In order to have a realistic example as well,\n", "we'll compare to the actual code\n", "in the `BaseLitModel` we use in the codebase\n", "as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fPARncfQ3ohz" }, "outputs": [], "source": [ "from text_recognizer.lit_models import BaseLitModel" ] }, { "cell_type": "markdown", "metadata": { "id": "myyL0vYU3z0a" }, "source": [ "A `pl.LightningModule` is a `torch.nn.Module`,\n", "so the basic definition looks the same:\n", "we need `__init__` and `forward`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-c0ylFO9rW_t" }, "outputs": [], "source": [ "class LinearRegression(pl.LightningModule):\n", "\n", "    def __init__(self):\n", "        super().__init__()  # just like in torch.nn.Module, we need to call the parent class __init__\n", "\n", "        # attach torch.nn.Modules as top level attributes during init, just like in a torch.nn.Module\n", "        self.model = torch.nn.Linear(in_features=1, out_features=1)\n", "        # we like to define the entire model as one torch.nn.Module -- typically in a separate class\n", "\n", "    # optionally, define a forward method\n", "    def forward(self, xs):\n", "        return self.model(xs)  # we like to just call the model's forward method" ] }, { "cell_type": "markdown", "metadata": { "id": "ZY1yoGTy6CBu" }, "source": [ "But just the minimal definition for a `torch.nn.Module` isn't sufficient.\n", "\n", "If we try to use the class above with the `Trainer`, we get an error:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tBWh_uHu5rmU" }, "outputs": [], "source": [ "import logging  # import some stdlib components to control what's displayed\n", "import textwrap\n", "import traceback\n", "\n", "\n", "try:  # try using the LinearRegression LightningModule defined above\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)  # hide some info for now\n", "\n", "    model = LinearRegression()\n", "\n", "    # we'll explain how the Trainer works in a bit\n", "    trainer = pl.Trainer(gpus=int(torch.cuda.is_available()), max_epochs=1)\n", "    trainer.fit(model=model)\n", "\n", "except pl.utilities.exceptions.MisconfigurationException as error:\n", "    print(\"Error:\", *textwrap.wrap(str(error), 80), sep=\"\\n\\t\")  # show the error without raising it\n", "\n", "finally:  # bring back info-level logging\n", "    logging.getLogger(\"pytorch_lightning\").setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "metadata": { "id": "s5ni7xe5CgUt" }, "source": [ "The error message
says we need some more methods.\n", "\n", "Two of them are mandatory components of the `LightningModule`: `.training_step` and `.configure_optimizers`." ] }, { "cell_type": "markdown", "metadata": { "id": "37BXP7nAoBik" }, "source": [ "#### `.training_step`" ] }, { "cell_type": "markdown", "metadata": { "id": "Ah9MjWz2plFv" }, "source": [ "The `training_step` method defines,\n", "naturally enough,\n", "what to do during a single step of training." ] }, { "cell_type": "markdown", "metadata": { "id": "plWEvWG_zRia" }, "source": [ "Roughly, it gets used like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "9RbxZ4idy-C5" }, "source": [ "```python\n", "\n", "# pseudocode modified from the Lightning documentation\n", "\n", "# put model in train mode\n", "model.train()\n", "\n", "for batch in train_dataloader:\n", " # run the train step\n", " loss = training_step(batch)\n", "\n", " # clear gradients\n", " optimizer.zero_grad()\n", "\n", " # backprop\n", " loss.backward()\n", "\n", " # update parameters\n", " optimizer.step()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "cemh_hGJ53nL" }, "source": [ "Effectively, it maps a batch to a loss value,\n", "so that PyTorch can backprop through that loss.\n", "\n", "The `.training_step` for our `LinearRegression` model is straightforward:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X8qW2VRRsPI2" }, "outputs": [], "source": [ "from typing import Tuple\n", "\n", "\n", "def training_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " xs, ys = batch # unpack the batch\n", " outs = self(xs) # apply the model\n", " loss = torch.nn.functional.mse_loss(outs, ys) # compute the (squared error) loss\n", " return loss\n", "\n", "\n", "LinearRegression.training_step = training_step" ] }, { "cell_type": "markdown", "metadata": { "id": "x2e8m3BRCIx6" }, "source": [ "If you've written PyTorch code before, you'll notice that we 
don't mention devices\n", "or other tensor metadata here -- that's handled for us by Lightning, which is a huge relief." ] }, { "cell_type": "markdown", "metadata": { "id": "FkvNpfwqpns5" }, "source": [ "You can additionally define\n", "a `validation_step` and a `test_step`\n", "to define the model's behavior during\n", "validation and testing loops.\n", "\n", "You're invited to define these steps\n", "in the exercises at the end of the lab.\n", "\n", "Inside this step is also where you might calculate other\n", "values related to inputs, outputs, and loss,\n", "like non-differentiable metrics (e.g. accuracy, precision, recall).\n", "\n", "So our `BaseLitModel`'s got a slightly more complex `training_step` method,\n", "and the details of the forward pass are deferred to `._run_on_batch` instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xpBkRczao1hr" }, "outputs": [], "source": [ "BaseLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "guhoYf_NoEyc" }, "source": [ "#### `.configure_optimizers`" ] }, { "cell_type": "markdown", "metadata": { "id": "SCIAWoCEtIU7" }, "source": [ "Thanks to `training_step` we've got a loss, and PyTorch can turn that into a gradient.\n", "\n", "But we need more than a gradient to do an update.\n", "\n", "We need an _optimizer_ that can make use of the gradients to update the parameters. In complex cases, we might need more than one optimizer (e.g. GANs).\n", "\n", "Our second required method, `.configure_optimizers`,\n", "sets up the `torch.optim.Optimizer`s \n", "(e.g. setting their hyperparameters\n", "and pointing them at the `Module`'s parameters)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bMlnRdIPzvDF" }, "source": [ "In pseudocode (modified from the Lightning documentation), it gets used something like this:" ] }, { "cell_type": "markdown", "metadata": { "id": "_WBnfJzszi49" }, "source": [ "```python\n", "optimizer = model.configure_optimizers()\n", "\n", "for batch_idx, batch in enumerate(data):\n", "\n", "    def closure():  # wrap the loss calculation\n", "        loss = model.training_step(batch, batch_idx, ...)\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        return loss\n", "\n", "    # optimizer can call the loss calculation as many times as it likes\n", "    optimizer.step(closure)  # some optimizers need this, like (L)-BFGS\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "SGsP3DBy7YzW" }, "source": [ "For our `LinearRegression` model,\n", "we just need to instantiate an optimizer and point it at the parameters of the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZWrWGgdVt21h" }, "outputs": [], "source": [ "def configure_optimizers(self: LinearRegression) -> torch.optim.Optimizer:\n", "    optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)  # https://fsdl.me/ol-reliable-img\n", "    return optimizer\n", "\n", "\n", "LinearRegression.configure_optimizers = configure_optimizers" ] }, { "cell_type": "markdown", "metadata": { "id": "ta2hs0OLwbtF" }, "source": [ "You can read more about optimization in Lightning,\n", "including how to manually control optimization\n", "instead of relying on default behavior,\n", "in the docs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KXINqlAgwfKy" }, "outputs": [], "source": [ "optimization_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/optimization.html\"\n", "optimization_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "zWdKdZDfxmb2" }, "source": [ "The `configure_optimizers` method for the `BaseLitModel`\n", "isn't that much more complex.\n",
"\n", "We just add support for learning rate schedulers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kyRbz0bEpWwd" }, "outputs": [], "source": [ "BaseLitModel.configure_optimizers??" ] }, { "cell_type": "markdown", "metadata": { "id": "ilQCfn7Nm_QP" }, "source": [ "# The `pl.Trainer`" ] }, { "cell_type": "markdown", "metadata": { "id": "RScc0ef97qlc" }, "source": [ "The `LightningModule` has already helped us organize our code,\n", "but it's not really useful until we combine it with the `Trainer`,\n", "which relies on the `LightningModule` interface to execute training, validation, and testing." ] }, { "cell_type": "markdown", "metadata": { "id": "bBdikPBF86Qp" }, "source": [ "The `Trainer` is where we make choices like how long to train\n", "(`max_epochs`, `min_epochs`, `max_time`, `max_steps`),\n", "what kind of acceleration (e.g. `gpus`) or distribution strategy to use,\n", "and other settings that might differ across training runs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YQ4KSdFP3E4Q" }, "outputs": [], "source": [ "trainer = pl.Trainer(max_epochs=20, gpus=int(torch.cuda.is_available()))" ] }, { "cell_type": "markdown", "metadata": { "id": "S2l3rGZK7-PL" }, "source": [ "Before we can actually use the `Trainer`, though,\n", "we also need a `torch.utils.data.DataLoader` --\n", "nothing new from PyTorch Lightning here,\n", "just vanilla PyTorch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OcUSD2jP4Ffo" }, "outputs": [], "source": [ "class CorrelatedDataset(torch.utils.data.Dataset):\n", "\n", " def __init__(self, N=10_000):\n", " self.N = N\n", " self.xs = torch.randn(size=(N, 1))\n", " self.ys = torch.randn_like(self.xs) + self.xs # correlated target data: y ~ N(x, 1)\n", "\n", " def __getitem__(self, idx):\n", " return (self.xs[idx], self.ys[idx])\n", "\n", " def __len__(self):\n", " return self.N\n", "\n", "\n", "dataset = CorrelatedDataset()\n", "tdl = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "o0u41JtA8qGo" }, "source": [ "We can fetch some sample data from the `DataLoader`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z1j6Gj9Ka0dJ" }, "outputs": [], "source": [ "example_xs, example_ys = next(iter(tdl)) # grabbing an example batch to print\n", "\n", "print(\"xs:\", example_xs[:10], sep=\"\\n\")\n", "print(\"ys:\", example_ys[:10], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Nnqk3mRv8dbW" }, "source": [ "and, since it's low-dimensional, visualize it\n", "and see what we're asking the model to learn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "33jcHbErbl6Q" }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", kind=\"scatter\");" ] }, { "cell_type": "markdown", "metadata": { "id": "pA7-4tJJ9fde" }, "source": [ "Now we're ready to run training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IY910O803oPU" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "print(\"loss before training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer.fit(model=model, train_dataloaders=tdl)\n", "\n", "print(\"loss after 
training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "sQBXYmLF_GoI" }, "source": [ "The loss after training should be less than the loss before training,\n", "and we can see that our model's predictions line up with the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jqcbA91x96-s" }, "outputs": [], "source": [ "ax = pd.DataFrame(data={\"x\": example_xs.flatten(), \"y\": example_ys.flatten()})\\\n", " .plot(x=\"x\", y=\"y\", legend=True, kind=\"scatter\", label=\"data\")\n", "\n", "inps = torch.arange(-2, 2, 0.5)[:, None]\n", "ax.plot(inps, model(inps).detach(), lw=2, color=\"k\", label=\"predictions\"); ax.legend();" ] }, { "cell_type": "markdown", "metadata": { "id": "gZkpsNfl3P8R" }, "source": [ "The `Trainer` promises to \"customize every aspect of training via flags\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_Q-c9b62_XFj" }, "outputs": [], "source": [ "pl.Trainer.__init__.__doc__.strip().split(\"\\n\")[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "He-zEwMB_oKH" }, "source": [ "and they mean _every_ aspect.\n", "\n", "The cell below prints all of the arguments for the `pl.Trainer` class --\n", "no need to memorize or even understand them all now,\n", "just skim it to see how many customization options there are:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8F_rRPL3lfPE" }, "outputs": [], "source": [ "print(pl.Trainer.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "4X8dGmR53kYU" }, "source": [ "It's probably easier to read them on the documentation website:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cqUj6MxRkppr" }, "outputs": [], "source": [ "trainer_docs_link = f\"https://pytorch-lightning.readthedocs.io/en/{version}/common/trainer.html\"\n", "trainer_docs_link" ] }, { "cell_type": "markdown", "metadata": { "id": 
"3T8XMYvr__Y5" }, "source": [ "# Training with PyTorch Lightning in the FSDL Codebase" ] }, { "cell_type": "markdown", "metadata": { "id": "_CtaPliTAxy3" }, "source": [ "The `LightningModule`s in the FSDL codebase\n", "are stored in the `lit_models` submodule of the `text_recognizer` module.\n", "\n", "For now, we've just got some basic models.\n", "We'll add more as we go." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NMe5z1RSAyo_" }, "outputs": [], "source": [ "!ls text_recognizer/lit_models" ] }, { "cell_type": "markdown", "metadata": { "id": "fZTYmIHbBu7g" }, "source": [ "We also have a folder called `training` now.\n", "\n", "This contains a script, `run_experiment.py`,\n", "that is used for running training jobs.\n", "\n", "In case you want to play around with the training code\n", "in a notebook, you can also load it as a module:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DRz9GbXzNJLM" }, "outputs": [], "source": [ "!ls training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Im9vLeyqBv_h" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.__doc__, training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "u2hcAXqHAV0v" }, "source": [ "We build the `Trainer` from command line arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yi50CDZul7Mm" }, "outputs": [], "source": [ "# how the trainer is initialized in the training script\n", "!grep \"pl.Trainer.from\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": { "id": "bZQheYJyAxlh" }, "source": [ "so all the configuration flexibility and complexity of the `Trainer`\n", "is available via the command line.\n", "\n", "Docs for the command line arguments for the trainer are accessible with `--help`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"XlSmSyCMAw7Z" }, "outputs": [], "source": [ "# displays the first few flags for controlling the Trainer from the command line\n", "!python training/run_experiment.py --help | grep \"pl.Trainer\" -A 24" ] }, { "cell_type": "markdown", "metadata": { "id": "mIZ_VRPcNMsM" }, "source": [ "We'll use `run_experiment` in\n", "[Lab 02b](http://fsdl.me/lab02b-colab)\n", "to train convolutional neural networks." ] }, { "cell_type": "markdown", "metadata": { "id": "z0siaL4Qumc_" }, "source": [ "# Extra Goodies" ] }, { "cell_type": "markdown", "metadata": { "id": "PkQSPnxQDBF6" }, "source": [ "The `LightningModule` and the `Trainer` are the minimum amount you need\n", "to get started with PyTorch Lightning.\n", "\n", "But they aren't all you need.\n", "\n", "There are many more features built into Lightning and its ecosystem.\n", "\n", "We'll cover three more here:\n", "- `pl.LightningDataModule`s, for organizing dataloaders and handling data in distributed settings\n", "- `pl.Callback`s, for adding \"optional\" extra features to model training\n", "- `torchmetrics`, for efficiently computing and logging " ] }, { "cell_type": "markdown", "metadata": { "id": "GOYHSLw_D8Zy" }, "source": [ "## `pl.LightningDataModule`" ] }, { "cell_type": "markdown", "metadata": { "id": "rpjTNGzREIpl" }, "source": [ "Where the `LightningModule` organizes our model and its optimizers,\n", "the `LightningDataModule` organizes our dataloading code." ] }, { "cell_type": "markdown", "metadata": { "id": "i_KkQ0iOWKD7" }, "source": [ "The class-level docstring explains the concept\n", "behind the class well\n", "and lists the main methods to be over-ridden:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IFTWHdsFV5WG" }, "outputs": [], "source": [ "print(pl.LightningDataModule.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rLiacppGB9BB" }, "source": [ "Let's upgrade our `CorrelatedDataset` from a PyTorch `Dataset` to a `LightningDataModule`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m1d62iC6Xv1i" }, "outputs": [], "source": [ "import math\n", "\n", "\n", "class CorrelatedDataModule(pl.LightningDataModule):\n", "\n", " def __init__(self, size=10_000, train_frac=0.8, batch_size=32):\n", " super().__init__() # again, mandatory superclass init, as with torch.nn.Modules\n", "\n", " # set some constants, like the train/val split\n", " self.size = size\n", " self.train_frac, self.val_frac = train_frac, 1 - train_frac\n", " self.train_indices = list(range(math.floor(self.size * train_frac)))\n", " # validation starts right after the last training index, so the splits don't overlap\n", " self.val_indices = list(range(len(self.train_indices), self.size))\n", "\n", " # under the hood, we've still got a torch Dataset\n", " self.dataset = CorrelatedDataset(N=size)" ] }, { "cell_type": "markdown", "metadata": { "id": "qQf-jUYRCi3m" }, "source": [ "`LightningDataModule`s are designed to work in distributed settings,\n", "where operations that set state\n", "(e.g. writing to disk or attaching something to `self` that you want to access later)\n", "need to be handled with care.\n", "\n", "Getting data ready for training is often a very stateful operation,\n", "so the `LightningDataModule` provides two separate methods for it:\n", "one called `setup` that handles any state that needs to be set up in each copy of the module\n", "(here, splitting the data and adding it to `self`)\n", "and one called `prepare_data` that handles any state that only needs to be set up once per machine\n", "(for example, downloading data from storage and writing it to the local disk)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mttu--rHX70r" }, "outputs": [], "source": [ "def setup(self, stage=None): # prepares state that needs to be set for each GPU on each node\n", " if stage == \"fit\" or stage is None: # other stages: \"test\", \"predict\"\n", " self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)\n", " self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)\n", "\n", "def prepare_data(self): # prepares state that needs to be set once per node\n", " pass # but we don't have any \"node-level\" computations\n", "\n", "\n", "CorrelatedDataModule.setup, CorrelatedDataModule.prepare_data = setup, prepare_data" ] }, { "cell_type": "markdown", "metadata": { "id": "Rh3mZrjwD83Y" }, "source": [ "We then define methods to return `DataLoader`s when requested by the `Trainer`.\n", "\n", "To run a testing loop that uses a `LightningDataModule`,\n", "you'll also need to define a `test_dataloader`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu9Ma3iKYPBd" }, "outputs": [], "source": [ "def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.train_dataset, batch_size=32)\n", "\n", "def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " return torch.utils.data.DataLoader(self.val_dataset, batch_size=32)\n", "\n", "CorrelatedDataModule.train_dataloader, CorrelatedDataModule.val_dataloader = train_dataloader, val_dataloader" ] }, { "cell_type": "markdown", "metadata": { "id": "aNodiN6oawX5" }, "source": [ "Now we're ready to run training using a datamodule:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JKBwoE-Rajqw" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "print(\"loss before training:\", 
torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "trainer.fit(model=model, datamodule=datamodule)\n", "\n", "print(\"loss after training:\", torch.mean(torch.square(model(dataset.xs) - dataset.ys)).item())" ] }, { "cell_type": "markdown", "metadata": { "id": "Bw6flh5Jf2ZP" }, "source": [ "Notice the warning: \"`Skipping val loop.`\"\n", "\n", "It's being raised because our minimal `LinearRegression` model\n", "doesn't have a `.validation_step` method.\n", "\n", "In the exercises, you're invited to add a validation step and resolve this warning." ] }, { "cell_type": "markdown", "metadata": { "id": "rJnoFx47ZjBw" }, "source": [ "In the FSDL codebase,\n", "we define the basic functions of a `LightningDataModule`\n", "in the `BaseDataModule` and defer details to subclasses:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PTPKvDDGXmOr" }, "outputs": [], "source": [ "from text_recognizer.data import BaseDataModule\n", "\n", "\n", "BaseDataModule??" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3mRlZecwaKB4" }, "outputs": [], "source": [ "from text_recognizer.data.mnist import MNIST\n", "\n", "\n", "MNIST??" ] }, { "cell_type": "markdown", "metadata": { "id": "uQbMY08qD-hm" }, "source": [ "## `pl.Callback`" ] }, { "cell_type": "markdown", "metadata": { "id": "NVe7TSNvHK4K" }, "source": [ "Lightning's `Callback` class is used to add \"nice-to-have\" features\n", "to training, validation, and testing\n", "that aren't strictly necessary for any model to run\n", "but are useful for many models." 
] }, { "cell_type": "markdown", "metadata": { "id": "RzU76wgFGw9N" }, "source": [ "A \"callback\" is a unit of code that's meant to be called later,\n", "based on some trigger.\n", "\n", "It's a very flexible system, which is why\n", "`Callback`s are used internally to implement lots of important Lightning features,\n", "including some we've already discussed, like `ModelCheckpoint` for saving during training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-msDjbKdHTxU" }, "outputs": [], "source": [ "pl.callbacks.__all__ # builtin Callbacks from Lightning" ] }, { "cell_type": "markdown", "metadata": { "id": "d6WRNXtHHkbM" }, "source": [ "The triggers, or \"hooks\", here, are specific points in the training, validation, and testing loop.\n", "\n", "The names of the hooks generally explain when the hook will be called,\n", "but you can always check the documentation for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3iHjjnU8Hvgg" }, "outputs": [], "source": [ "hooks = \", \".join([method for method in dir(pl.Callback) if method.startswith(\"on_\")])\n", "print(\"hooks:\", *textwrap.wrap(hooks, width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "2E2M7O2cGdj7" }, "source": [ "You can define your own `Callback` by inheriting from `pl.Callback`\n", "and over-riding one of the \"hook\" methods --\n", "much the same way that you define your own `LightningModule`\n", "by writing your own `.training_step` and `.configure_optimizers`.\n", "\n", "Let's define a silly `Callback` just to demonstrate the idea:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UodFQKAGEJlk" }, "outputs": [], "source": [ "class HelloWorldCallback(pl.Callback):\n", "\n", " def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):\n", " print(\"👋 hello from the start of the training epoch!\")\n", "\n", " def on_validation_epoch_end(self, trainer: pl.Trainer, 
pl_module: pl.LightningModule):\n", " print(\"👋 hello from the end of the validation epoch!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "MU7oIpyEGoaP" }, "source": [ "This callback will print a message whenever the training epoch starts\n", "and whenever the validation epoch ends.\n", "\n", "Different \"hooks\" have different information directly available.\n", "\n", "For example, you can directly access the batch information\n", "inside the `on_train_batch_start` and `on_train_batch_end` hooks:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U17Qo_i_GCya" }, "outputs": [], "source": [ "import random\n", "\n", "\n", "def on_train_batch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):\n", " if random.random() > 0.995:\n", " print(f\"👋 hello from inside the lucky batch, #{batch_idx}!\")\n", "\n", "\n", "HelloWorldCallback.on_train_batch_start = on_train_batch_start" ] }, { "cell_type": "markdown", "metadata": { "id": "LVKQXZOwQNGJ" }, "source": [ "We provide the callbacks when initializing the `Trainer`,\n", "then they are invoked during model fitting." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-XHXZ64-ETCz" }, "outputs": [], "source": [ "model = LinearRegression()\n", "\n", "datamodule = CorrelatedDataModule()\n", "\n", "trainer = pl.Trainer( # we instantiate and provide the callback here, but nothing happens yet\n", " max_epochs=10, gpus=int(torch.cuda.is_available()), callbacks=[HelloWorldCallback()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UEHUUhVOQv6K" }, "outputs": [], "source": [ "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "pP2Xj1woFGwG" }, "source": [ "You can read more about callbacks in the documentation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "COHk5BZvFJN_" }, "outputs": [], "source": [ "callback_docs_url = f\"https://pytorch-lightning.readthedocs.io/en/{version}/extensions/callbacks.html\"\n", "callback_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2K9e44iEGCR" }, "source": [ "## `torchmetrics`" ] }, { "cell_type": "markdown", "metadata": { "id": "dO-UIFKyJCqJ" }, "source": [ "DNNs are also finicky and break silently:\n", "rather than crashing, they just start doing the wrong thing.\n", "Without careful monitoring, that wrong thing can be invisible\n", "until long after it has done a lot of damage to you, your team, or your users.\n", "\n", "We want to calculate metrics so we can monitor what's happening during training and catch bugs --\n", "or even achieve [\"observability\"](https://thenewstack.io/observability-a-3-year-retrospective/),\n", "meaning we can also determine\n", "how to fix bugs in training just by viewing logs." 
] }, { "cell_type": "markdown", "metadata": { "id": "z4YMyUI0Jr2f" }, "source": [ "But DNN training is also performance-sensitive.\n", "Training runs for large language models have budgets that are\n", "more comparable to building an apartment complex\n", "than they are to the build jobs of traditional software pipelines.\n", "\n", "Slowing down training even a small amount can add a substantial dollar cost,\n", "obviating the benefits of catching and fixing bugs more quickly.\n", "\n", "Also, implementing metric calculation during training adds extra work,\n", "much like the other software engineering best practices which it closely resembles,\n", "namely test-writing and monitoring.\n", "This distracts and detracts from higher-leverage research work." ] }, { "cell_type": "markdown", "metadata": { "id": "sbvWjiHSIxzM" }, "source": [ "\n", "The `torchmetrics` library, which began its life as `pytorch_lightning.metrics`,\n", "resolves these issues by providing a `Metric` class that\n", "incorporates performance best practices,\n", "like smart accumulation across batches and over devices,\n", "defines a unified interface,\n", "and integrates with Lightning's built-in logging." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "21y3lgvwEKPC" }, "outputs": [], "source": [ "import torchmetrics\n", "\n", "\n", "tm_version = torchmetrics.__version__\n", "print(\"metrics:\", *textwrap.wrap(\", \".join(torchmetrics.__all__), width=80), sep=\"\\n\\t\")" ] }, { "cell_type": "markdown", "metadata": { "id": "9TuPZkV1gfFE" }, "source": [ "Like the `LightningModule`, `torchmetrics.Metric` inherits from `torch.nn.Module`.\n", "\n", "That's because metric calculation, like module application, is typically\n", "1) an array-heavy computation that\n", "2) relies on persistent state\n", "(parameters for `Module`s, running values for `Metric`s) and\n", "3) benefits from acceleration and\n", "4) can be distributed over devices and nodes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "leiiI_QDS2_V" }, "outputs": [], "source": [ "issubclass(torchmetrics.Metric, torch.nn.Module)" ] }, { "cell_type": "markdown", "metadata": { "id": "Wy8MF2taP8MV" }, "source": [ "Documentation for the version of `torchmetrics` we're using can be found here:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LN4ashooP_tM" }, "outputs": [], "source": [ "torchmetrics_docs_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/\"\n", "torchmetrics_docs_url" ] }, { "cell_type": "markdown", "metadata": { "id": "5aycHhZNXwjr" }, "source": [ "In the `BaseLitModel`,\n", "we use the `torchmetrics.Accuracy` metric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Vyq4IjmBXzTv" }, "outputs": [], "source": [ "BaseLitModel.__init__??" ] }, { "cell_type": "markdown", "metadata": { "id": "KPoTH50YfkMF" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "hD_6PVAeflWw" }, "source": [ "### 🌟 Add a `validation_step` to the `LinearRegression` class." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5KKbAN9eK281" }, "outputs": [], "source": [ "def validation_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "\n", "LinearRegression.validation_step = validation_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AnPPHAPxFCEv" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "# if your code is working, you should see results for the validation loss in the output\n", "trainer.fit(model=model, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": { "id": "u42zXktOFDhZ" }, "source": [ "### 🌟🌟 Add a `test_step` to the `LinearRegression` class and a `test_dataloader` to the `CorrelatedDataModule`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cbWfqvumFESV" }, "outputs": [], "source": [ "def test_step(self: pl.LightningModule, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " pass # your code here\n", "\n", "LinearRegression.test_step = test_step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pB96MpibLeJi" }, "outputs": [], "source": [ "class CorrelatedDataModuleWithTest(pl.LightningDataModule):\n", "\n", " def __init__(self, N=10_000, N_test=10_000): # reimplement __init__ here\n", " super().__init__() # don't forget this!\n", " self.dataset = None\n", " self.test_dataset = None # define a test set -- another sample from the same distribution\n", "\n", " def setup(self, stage=None):\n", " pass\n", "\n", " def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:\n", " pass # create a dataloader for the test set here" ] }, { "cell_type": "code", "execution_count": null,
"metadata": { "id": "1jq3dcugMMOu" }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModuleWithTest()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# we run testing without fitting here\n", "trainer.test(model=model, datamodule=datamodule) # if your code is working, you should see performance on the test set here" ] }, { "cell_type": "markdown", "metadata": { "id": "JHg4MKmJPla6" }, "source": [ "### 🌟🌟🌟 Make a version of the `LinearRegression` class that calculates the `ExplainedVariance` metric during training and validation." ] }, { "cell_type": "markdown", "metadata": { "id": "M_1AKGWRR2ai" }, "source": [ "The \"variance explained\" is a useful metric for comparing regression models --\n", "its values are interpretable and comparable across datasets, unlike raw loss values.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vLecK4CsQWKk" }, "source": [ "Read the \"TorchMetrics in PyTorch Lightning\" guide for details on how to\n", "add metrics and metric logging\n", "to a `LightningModule`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cWy0HyG4RYnX" }, "outputs": [], "source": [ "torchmetrics_guide_url = f\"https://torchmetrics.readthedocs.io/en/v{tm_version}/pages/lightning.html\"\n", "torchmetrics_guide_url" ] }, { "cell_type": "markdown", "metadata": { "id": "UoSQ3y6sSTvP" }, "source": [ "And check out the docs for `ExplainedVariance` to see how it's calculated:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpGuRK2FRHh1" }, "outputs": [], "source": [ "print(torchmetrics.ExplainedVariance.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "_EAtpWXrSVR1" }, "source": [ "You'll want to start the `LinearRegression` class over from scratch,\n", "since the `__init__` and `{training, validation, test}_step` methods need to be rewritten." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rGtWt3_5SYTn" }, "outputs": [], "source": [ "# your code here" ] }, { "cell_type": "markdown", "metadata": { "id": "oFWNr1SfS5-r" }, "source": [ "You can test your code by running fitting and testing.\n", "\n", "To see whether it's working,\n", "[call `self.log` inside the `_step` methods](https://torchmetrics.readthedocs.io/en/v0.7.1/pages/lightning.html)\n", "with the\n", "[keyword argument `prog_bar=True`](https://pytorch-lightning.readthedocs.io/en/1.6.1/api/pytorch_lightning.core.LightningModule.html#pytorch_lightning.core.LightningModule.log).\n", "You should see the explained variance show up in the output alongside the loss." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jse95DGCS6gR", "scrolled": false }, "outputs": [], "source": [ "model = LinearRegression()\n", "datamodule = CorrelatedDataModule()\n", "\n", "dataset = datamodule.dataset\n", "\n", "trainer = pl.Trainer(max_epochs=10, gpus=int(torch.cuda.is_available()))\n", "\n", "# if your code is working, you should see explained variance in the progress bar/logs\n", "trainer.fit(model=model, datamodule=datamodule)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab02a_lightning.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab08/notebooks/lab02b_cnn.ipynb 
================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 02b: Training a CNN on Synthetic Handwriting Data" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Fundamental principles for building neural networks with convolutional components\n", "- How to use Lightning's training framework via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 2\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why convolutions?"
] }, { "cell_type": "markdown", "metadata": { "id": "T9HoYWZKtTE_" }, "source": [ "The most basic neural networks,\n", "multi-layer perceptrons,\n", "are built by alternating\n", "parameterized linear transformations\n", "with non-linear transformations.\n", "\n", "This combination is capable of expressing\n", "[functions of arbitrary complexity](http://neuralnetworksanddeeplearning.com/chap4.html),\n", "so long as those functions\n", "take in fixed-size arrays and return fixed-size arrays.\n", "\n", "```python\n", "def any_function_you_can_imagine(x: torch.Tensor[\"A\"]) -> torch.Tensor[\"B\"]:\n", " return some_mlp_that_might_be_impractically_huge(x)\n", "```\n", "\n", "But not all functions have that type signature.\n", "\n", "For example, we might want to identify the content of images\n", "that have different sizes.\n", "Without gross hacks,\n", "an MLP won't be able to solve this problem,\n", "even though it seems simple enough." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6LjfV3o6tTFA" }, "outputs": [], "source": [ "import random\n", "\n", "import IPython.display as display\n", "\n", "randsize = 10 ** (random.random() * 2 + 1)\n", "\n", "Url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/emnist/U.png\"\n", "\n", "# run multiple times to display the same image at different sizes\n", "# the content of the image remains unambiguous\n", "display.Image(url=Url, width=randsize, height=randsize)" ] }, { "cell_type": "markdown", "metadata": { "id": "c9j6YQRftTFB" }, "source": [ "Even worse, MLPs are too general to be efficient.\n", "\n", "Each layer applies an unstructured matrix to its inputs.\n", "But most of the data we might want to apply them to is highly structured,\n", "and taking advantage of that structure can make our models more efficient.\n", "\n", "It may seem appealing to use an unstructured model:\n", "it can in principle learn any function.\n", "But\n", "[most functions are monstrous outrages against common 
sense](https://en.wikipedia.org/wiki/Weierstrass_function#Density_of_nowhere-differentiable_functions).\n", "It is useful to encode some of our assumptions\n", "about the kinds of functions we might want to learn\n", "from our data into our model's architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "jvC_yZvmuwgJ" }, "source": [ "## Convolutions are the local, translation-equivariant linear transforms." ] }, { "cell_type": "markdown", "metadata": { "id": "PhnRx_BZtTFC" }, "source": [ "One of the most common types of structure in data is \"locality\" --\n", "the most relevant information for understanding or predicting a pixel\n", "is a small number of pixels around it.\n", "\n", "Locality is a fundamental feature of the physical world,\n", "so it shows up in data drawn from physical observations,\n", "like photographs and audio recordings.\n", "\n", "Locality means most meaningful linear transformations of our input\n", "only have large weights in a small number of entries that are close to one another,\n", "rather than having equally large weights in all entries." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SSnkzV2_tTFC" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "generic_linear_transform = torch.randn(8, 1)\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "local_linear_transform = torch.tensor([\n", " [0, 0, 0] + [random.random(), random.random(), random.random()] + [0, 0]]).T\n", "print(\"local:\", local_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "0nCD75NwtTFD" }, "source": [ "Another type of structure commonly observed is \"translation equivariance\" --\n", "the top-left pixel position is not, in itself, meaningfully different\n", "from the bottom-right position\n", "or a position in the middle of the image.\n", "Relative relationships matter more than absolute relationships.\n", "\n", "Translation equivariance arises in images because there is generally no privileged\n", "vantage point for taking the image.\n", "We could just as easily have taken the image while standing a few feet to the left or right,\n", "and all of its contents would shift along with our change in perspective.\n", "\n", "Translation equivariance means that a linear transformation that is meaningful at one position\n", "in our input is likely to be meaningful at all other points.\n", "We can learn something about a linear transformation from a datapoint where it is useful\n", "in the bottom-left and then apply it to another datapoint where it's useful in the top-right." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "srvI7JFAtTFE" }, "outputs": [], "source": [ "generic_linear_transform = torch.arange(8)[:, None]\n", "print(\"generic:\", generic_linear_transform, sep=\"\\n\")\n", "\n", "equivariant_linear_transform = torch.stack([torch.roll(generic_linear_transform[:, 0], ii) for ii in range(8)], dim=1)\n", "print(\"translation equivariant:\", equivariant_linear_transform, sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qF576NCvtTFE" }, "source": [ "A linear transformation that is translation equivariant\n", "[is called a _convolution_](https://en.wikipedia.org/wiki/Convolution#Translational_equivariance).\n", "\n", "If the weights of that linear transformation are mostly zero\n", "except for a few that are close to one another,\n", "that convolution is said to have a _kernel_." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9tp4tBgWtTFF" }, "outputs": [], "source": [ "# the equivalent of torch.nn.Linear, but for a 1-dimensional convolution\n", "conv_layer = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)\n", "\n", "conv_layer.weight # aka kernel" ] }, { "cell_type": "markdown", "metadata": { "id": "deXA_xS6tTFF" }, "source": [ "Instead of using normal matrix multiplication to apply the kernel to the input,\n", "we apply that kernel repeatedly,\n", "\"sliding\" it over the input to produce an output.\n", "\n", "Every convolution kernel has an equivalent matrix form,\n", "which can be matrix multiplied with the input to create the output:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFoSsa5DtTFF" }, "outputs": [], "source": [ "conv_kernel_as_vector = torch.hstack([conv_layer.weight[0][0], torch.zeros(5)])\n", "conv_layer_as_matrix = torch.stack([torch.roll(conv_kernel_as_vector, ii) for ii in range(8)], dim=0)\n", "print(\"convolution matrix:\", conv_layer_as_matrix, sep=\"\\n\")" ] }, { "cell_type":
"markdown", "metadata": { "id": "VJyRtf9NtTFG" }, "source": [ "> Under the hood, the actual operation that implements the application of a convolutional kernel\n", "need not look like either of these\n", "(common approaches include\n", "[Winograd-type algorithms](https://arxiv.org/abs/1509.09308)\n", "and [Fast Fourier Transform-based algorithms](https://arxiv.org/abs/1312.5851))." ] }, { "cell_type": "markdown", "metadata": { "id": "xytivdcItTFG" }, "source": [ "Though they may seem somewhat arbitrary and technical,\n", "convolutions are actually a deep and fundamental piece of mathematics and computer science.\n", "Fundamental as in\n", "[closely related to the multiplication algorithm we learn as children](https://charlesfrye.github.io/math/2019/02/20/multiplication-convoluted-part-one.html)\n", "and deep as in\n", "[closely related to the Fourier transform](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution).\n", "Generalized convolutions can show up\n", "wherever there is some kind of \"sum\" over some kind of \"paths\",\n", "as is common in dynamic programming.\n", "\n", "In the context of this course,\n", "we don't have time to dive much deeper on convolutions or convolutional neural networks.\n", "\n", "See Chris Olah's blog series\n", "([1](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/),\n", "[2](https://colah.github.io/posts/2014-07-Understanding-Convolutions/),\n", "[3](https://colah.github.io/posts/2014-12-Groups-Convolution/))\n", "for a friendly introduction to the mathematical view of convolution.\n", "\n", "For more on convolutional neural network architectures, see\n", "[the lecture notes from Stanford's 2020 \"Deep Learning for Computer Vision\" course](https://cs231n.github.io/convolutional-networks/)." ] }, { "cell_type": "markdown", "metadata": { "id": "uCJTwCWYzRee" }, "source": [ "## We apply two-dimensional convolutions to images." 
] }, { "cell_type": "markdown", "metadata": { "id": "a8RKOPAIx0O2" }, "source": [ "In building our text recognizer,\n", "we're working with images.\n", "Images have two dimensions of translation equivariance:\n", "left/right and up/down.\n", "So we use two-dimensional convolutions,\n", "instantiated in `torch.nn` as `nn.Conv2d` layers.\n", "Note that convolutional neural networks for images\n", "are so popular that when the term \"convolution\"\n", "is used without qualifier in a neural network context,\n", "it can be taken to mean two-dimensional convolutions.\n", "\n", "Where `Linear` layers took in batches of vectors of a fixed size\n", "and returned batches of vectors of a fixed size,\n", "`Conv2d` layers take in batches of two-dimensional _stacked feature maps_\n", "and return batches of two-dimensional stacked feature maps.\n", "\n", "A pseudocode type signature based on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping)\n", "might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "sJvMdHL7w_lu" }, "source": [ "```python\n", "StackedFeatureMapIn = torch.Tensor[\"batch\", \"in_channels\", \"in_height\", \"in_width\"]\n", "StackedFeatureMapOut = torch.Tensor[\"batch\", \"out_channels\", \"out_height\", \"out_width\"]\n", "def same_convolution_2d(x: StackedFeatureMapIn) -> StackedFeatureMapOut:\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nSMC8Fw3zPSz" }, "source": [ "Here, \"map\" is meant to evoke space:\n", "our feature maps tell us where\n", "features are spatially located.\n", "\n", "An RGB image is a stacked feature map.\n", "It is composed of three feature maps.\n", "The first tells us where the \"red\" feature is present,\n", "the second \"green\", the third \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jIXT-mym3ljt" }, "outputs": [], "source": [ "display.Image(\n", " 
url=\"https://upload.wikimedia.org/wikipedia/commons/5/56/RGB_channels_separation.png?20110219015028\")" ] }, { "cell_type": "markdown", "metadata": { "id": "8WfCcO5xJ-hG" }, "source": [ "When we apply a convolutional layer to a stacked feature map with some number of channels,\n", "we get back a stacked feature map with some (possibly different) number of channels.\n", "\n", "This output is also a stack of feature maps,\n", "and so it is a perfectly acceptable\n", "input to another convolutional layer.\n", "That means we can compose convolutional layers together,\n", "just as we composed generic linear layers together.\n", "We again weave non-linear functions in between our linear convolutions,\n", "creating a _convolutional neural network_, or CNN." ] }, { "cell_type": "markdown", "metadata": { "id": "R18TsGubJ_my" }, "source": [ "## Convolutional neural networks build up visual understanding layer by layer." ] }, { "cell_type": "markdown", "metadata": { "id": "eV03KmYBz2QM" }, "source": [ "What is the equivalent of the labels, red/green/blue,\n", "for the channels in these feature maps?\n", "What does a high activation in some position in channel 32\n", "of the fifteenth layer of my network tell me?\n", "\n", "There is no guaranteed way to automatically determine the answer,\n", "nor is there a guarantee that the result is human-interpretable.\n", "OpenAI's Clarity team spent several years \"reverse engineering\"\n", "state-of-the-art convolutional neural networks trained on photographs\n", "and found that many of these channels are\n", "[directly interpretable](https://distill.pub/2018/building-blocks/).\n", "\n", "For example, they found that if they pass an image through\n", "[GoogLeNet](https://doi.org/10.1109/cvpr.2015.7298594),\n", "aka InceptionV1,\n", "the winner of the\n", "[2014 ImageNet Very Large Scale Visual Recognition Challenge](https://www.image-net.org/challenges/LSVRC/2014/)," ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "64KJR70q6dCh" 
}, "outputs": [], "source": [ "# a sample image\n", "display.Image(url=\"https://distill.pub/2018/building-blocks/examples/input_images/dog_cat.jpeg\")" ] }, { "cell_type": "markdown", "metadata": { "id": "hJ7CvvG78CZ5" }, "source": [ "the features become increasingly complex,\n", "with channels in early layers (left)\n", "acting as maps for simple things like \"high frequency power\" or \"45 degree black-white edge\"\n", "and channels in later layers (right)\n", "acting as feature maps for increasingly abstract concepts,\n", "like \"circle\" and eventually \"floppy round ear\" or \"pointy ear\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6w5_RR8d9jEY" }, "outputs": [], "source": [ "# from https://distill.pub/2018/building-blocks/\n", "display.Image(url=\"https://fsdl-public-assets.s3.us-west-2.amazonaws.com/distill-feature-attrib.png\", width=1024)" ] }, { "cell_type": "markdown", "metadata": { "id": "HLiqEwMY_Co0" }, "source": [ "> The small square images depict a heuristic estimate\n", "of what the entire collection of feature maps\n", "at a given layer represents (layer IDs at bottom).\n", "They are arranged in a spatial grid and their sizes represent\n", "the total magnitude of the layer's activations at that position.\n", "For details and interactivity, see\n", "[the original Distill article](https://distill.pub/2018/building-blocks/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "vl8XlEsaA54W" }, "source": [ "In the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "blogpost series,\n", "the OpenAI Clarity team\n", "combines careful examination of weights\n", "with direct experimentation\n", "to build an understanding of how these higher-level features\n", "are constructed in GoogLeNet.\n", "\n", "For example,\n", "they are able to provide reasonable interpretations for\n", "[almost every channel in the first five layers](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "The cell below will pull down their \"weight explorer\"\n", "and embed it in this notebook.\n", "By default, it starts on\n", "[the 52nd channel in the `conv2d1` layer](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d1_52.html),\n", "which constructs a large, phase-invariant\n", "[Gabor filter](https://en.wikipedia.org/wiki/Gabor_filter)\n", "from smaller, phase-sensitive filters.\n", "It is in turn used to construct\n", "[curve](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_180.html)\n", "and\n", "[texture](https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/conv2d2_114.html)\n", "detectors --\n", "click on any image to navigate to the weight explorer page\n", "for that channel\n", "or change the `layer` and `idx`\n", "arguments.\n", "For additional context,\n", "check out the\n", "[Early Vision in InceptionV1 blogpost](https://distill.pub/2020/circuits/early-vision/).\n", "\n", "Click the \"View this neuron in the OpenAI Microscope\" link\n", "for an even richer interactive view,\n", "including activations on sample images\n", "([example](https://microscope.openai.com/models/inceptionv1/conv2d1_0/52)).\n", "\n", "The\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "which this explorer accompanies\n", "is chock-full of empirical observations, theoretical speculation, and nuggets of 
wisdom\n", "that are invaluable for developing intuition about both\n", "convolutional networks in particular and visual perception in general." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "I4-hkYjdB-qQ" }, "outputs": [], "source": [ "layers = [\"conv2d0\", \"conv2d1\", \"conv2d2\", \"mixed3a\", \"mixed3b\"]\n", "layer = layers[1]\n", "idx = 52\n", "\n", "weight_explorer = display.IFrame(\n", " src=f\"https://storage.googleapis.com/distill-circuits/inceptionv1-weight-explorer/{layer}_{idx}.html\", width=1024, height=720)\n", "weight_explorer.iframe = 'style=\"background: #FFF\";\\n><'.join(weight_explorer.iframe.split(\"><\")) # inject background color\n", "weight_explorer" ] }, { "cell_type": "markdown", "metadata": { "id": "NJ6_PCmVtTFH" }, "source": [ "# Applying convolutions to handwritten characters: `CNN`s on `EMNIST`" ] }, { "cell_type": "markdown", "metadata": { "id": "N--VkRtR5Yr-" }, "source": [ "If we load up the `CNN` class from `text_recognizer.models`,\n", "we'll see that a `data_config` is required to instantiate the model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N3MA--zytTFH" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "text_recognizer.models.CNN??" ] }, { "cell_type": "markdown", "metadata": { "id": "7yCP46PO6XDg" }, "source": [ "So before we can make our convolutional network and train it,\n", "we'll need to get a hold of some data.\n", "This isn't a general constraint by the way --\n", "it's an implementation detail of the `text_recognizer` library.\n", "But datasets and models are generally coupled,\n", "so it's common for them to share configuration information." 
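] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough sketch of this kind of coupling,\n",
"the snippet below sizes a stand-in model from a configuration exported by the data.\n",
"The keys `input_dims` and `mapping` and the class `TinyModelSpec` are assumptions made for illustration;\n",
"check the `CNN` source printed above for the keys that `text_recognizer` actually reads.\n",
"\n",
"```python\n",
"# hypothetical configuration exported by a dataset object\n",
"data_config = {\n",
"    'input_dims': (1, 28, 28),  # channels, height, width\n",
"    'mapping': ['0', '1', 'a', 'b'],  # class index -> character\n",
"}\n",
"\n",
"\n",
"class TinyModelSpec:\n",
"    # stand-in for a model class that sizes its layers from the data\n",
"    def __init__(self, data_config):\n",
"        channels, height, width = data_config['input_dims']\n",
"        self.input_size = channels * height * width  # elements per input\n",
"        self.num_classes = len(data_config['mapping'])  # one output per class\n",
"\n",
"\n",
"spec = TinyModelSpec(data_config)\n",
"print(spec.input_size, spec.num_classes)\n",
"```\n",
"\n",
"Because the sizes are read from the data rather than hard-coded,\n",
"swapping in a dataset with more classes changes the model's output layer without editing the model."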
] }, { "cell_type": "markdown", "metadata": { "id": "6Z42K-jjtTFH" }, "source": [ "## The `EMNIST` Handwritten Character Dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "oiifKuu4tTFH" }, "source": [ "We could just use `MNIST` here,\n", "as we did in\n", "[the first lab](https://fsdl.me/lab01-colab).\n", "\n", "But we're aiming to eventually build a handwritten text recognition system,\n", "which means we need to handle letters and punctuation,\n", "not just numbers.\n", "\n", "So we instead use _EMNIST_,\n", "or [Extended MNIST](https://paperswithcode.com/paper/emnist-an-extension-of-mnist-to-handwritten),\n", "which includes letters and punctuation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ePZW1Tfa00K" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist = text_recognizer.data.EMNIST() # configure\n", "print(emnist.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "D_yjBYhla6qp" }, "source": [ "We've built a PyTorch Lightning `DataModule`\n", "to encapsulate all the code needed to get this dataset ready to go:\n", "downloading to disk,\n", "[reformatting to make loading faster](https://www.h5py.org/),\n", "and splitting into training, validation, and test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ty2vakBBtTFI" }, "outputs": [], "source": [ "emnist.prepare_data() # download, save to disk\n", "emnist.setup() # create torch.utils.data.Datasets, do train/val split" ] }, { "cell_type": "markdown", "metadata": { "id": "5h9bAXcu8l5J" }, "source": [ "A brief aside: you might be wondering where this data goes.\n", "Datasets are saved to disk inside the repo folder,\n", "but not tracked in version control.\n", "`git` works well for versioning source code\n", "and other text files, but it's a poor fit for large binary data.\n", "We only track and version metadata." 
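] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One common way to version metadata without versioning the data itself\n",
"is to record a checksum of each data file.\n",
"The sketch below is only an illustration using the standard library:\n",
"the file name, the fake file contents, and the layout of the `metadata` dictionary are assumptions,\n",
"not the format used in this repository.\n",
"\n",
"```python\n",
"import hashlib\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"\n",
"def sha256_of_file(path):\n",
"    # hash the file's bytes; any change to the data changes the digest\n",
"    return hashlib.sha256(Path(path).read_bytes()).hexdigest()\n",
"\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    # stand-in for a large binary file that is kept out of version control\n",
"    data_file = Path(tmp) / 'dataset.bin'\n",
"    data_file.write_bytes(b'pretend this is a large binary dataset')\n",
"\n",
"    # hypothetical metadata record: small and text-based, so easy to track with git\n",
"    metadata = {'filename': data_file.name, 'sha256': sha256_of_file(data_file)}\n",
"\n",
"    # later, anyone holding the metadata can verify their local copy of the data\n",
"    is_valid = sha256_of_file(data_file) == metadata['sha256']\n",
"```\n",
"\n",
"The record stays a few lines long no matter how large the data file grows."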
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E5cwDCM88SnU" }, "outputs": [], "source": [ "!echo {emnist.data_dirname()}\n", "!ls {emnist.data_dirname()}\n", "!ls {emnist.data_dirname() / \"raw\" / \"emnist\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "IdsIBL9MtTFI" }, "source": [ "This class comes with a pretty printing method\n", "for quick examination of some of that metadata and basic descriptive statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Cyw66d6GtTFI" }, "outputs": [], "source": [ "emnist" ] }, { "cell_type": "markdown", "metadata": { "id": "QT0burlOLgoH" }, "source": [ "\n", "> You can add pretty printing to your own Python classes by writing\n", "`__str__` or `__repr__` methods for them.\n", "The former is generally expected to be human-readable,\n", "while the latter is generally expected to be machine-readable;\n", "we've broken with that custom here and used `__repr__`. " ] }, { "cell_type": "markdown", "metadata": { "id": "XJF3G5idtTFI" }, "source": [ "Because we've run `.prepare_data` and `.setup`,\n", "we can expect that this `DataModule` is ready to provide a `DataLoader`\n", "if we invoke the right method --\n", "sticking to the PyTorch Lightning API brings these kinds of convenient guarantees\n", "even when we're not using the `Trainer` class itself,\n", "[as described in Lab 2a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XJghcZkWtTFI" }, "outputs": [], "source": [ "xs, ys = next(iter(emnist.train_dataloader()))" ] }, { "cell_type": "markdown", "metadata": { "id": "40FWjMT-tTFJ" }, "source": [ "Run the cell below to inspect random elements of this batch." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0hywyEI_tTFJ" }, "outputs": [], "source": [ "import wandb\n", "\n", "idx = random.randint(0, len(xs) - 1)\n", "\n", "print(emnist.mapping[ys[idx]])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "hdg_wYWntTFJ" }, "source": [ "## Putting convolutions in a `torch.nn.Module`" ] }, { "cell_type": "markdown", "metadata": { "id": "JGuSx_zvtTFJ" }, "source": [ "Because we have the data,\n", "we now have a `data_config`\n", "and can instantiate the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rxLf7-5jtTFJ" }, "outputs": [], "source": [ "data_config = emnist.config()\n", "\n", "cnn = text_recognizer.models.CNN(data_config)\n", "cnn # reveals the nn.Modules attached to our nn.Module" ] }, { "cell_type": "markdown", "metadata": { "id": "jkeJNVnIMVzJ" }, "source": [ "We can run this network on our inputs,\n", "but we don't expect it to produce correct outputs without training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4EwujOGqMAZY" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = cnn(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "P3L8u0estTFJ" }, "source": [ "We can inspect the `.forward` method to see how these `nn.Module`s are used.\n", "\n", "> Note: we encourage you to read through the code --\n", "either inside the notebooks, as below,\n", "in your favorite text editor locally, or\n", "[on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs).\n", "There's lots of useful bits of Python that we don't have time to cover explicitly in the labs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RtA0W8jvtTFJ" }, "outputs": [], "source": [ "cnn.forward??" 
] }, { "cell_type": "markdown", "metadata": { "id": "VCycQ88gtTFK" }, "source": [ "We apply convolutions followed by non-linearities,\n", "with intermittent \"pooling\" layers that apply downsampling --\n", "similar to the 1989\n", "[LeNet](https://doi.org/10.1162%2Fneco.1989.1.4.541)\n", "architecture or the 2012\n", "[AlexNet](https://doi.org/10.1145%2F3065386)\n", "architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "qkGJCnMttTFK" }, "source": [ "The final classification is performed by an MLP.\n", "\n", "In order to get vectors to pass into that MLP,\n", "we first apply `torch.flatten`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WZPhw7ufAKZ7" }, "outputs": [], "source": [ "torch.flatten(torch.Tensor([[1, 2], [3, 4]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "jCoCa3vCNM8j" }, "source": [ "## Design considerations for CNNs" ] }, { "cell_type": "markdown", "metadata": { "id": "dDLEMnPINTj7" }, "source": [ "Since the release of AlexNet,\n", "there has been a feverish decade of engineering and innovation in CNNs --\n", "[dilated convolutions](https://arxiv.org/abs/1511.07122),\n", "[residual connections](https://arxiv.org/abs/1512.03385), and\n", "[batch normalization](https://arxiv.org/abs/1502.03167)\n", "came out in 2015 alone, and\n", "[work continues](https://arxiv.org/abs/2201.03545) --\n", "so we can only scratch the surface in this course and\n", "[the devil is in the details](https://arxiv.org/abs/1405.3531v4).\n", "\n", "The progress of DNNs in general and CNNs in particular\n", "has been mostly evolutionary,\n", "with lots of good ideas that didn't work out\n", "and weird hacks that stuck around because they did.\n", "That can make it very hard to design a fresh architecture\n", "from first principles that's anywhere near as effective as existing architectures.\n", "You're better off tweaking and mutating an existing architecture\n", "than trying to design one yourself.\n", "\n", "If you're not 
keeping close tabs on the field,\n", "when you first start looking for an architecture to base your work off of\n", "it's best to go to trusted aggregators, like\n", "[Torch IMage Models](https://github.com/rwightman/pytorch-image-models),\n", "or `timm`, on GitHub, or\n", "[Papers With Code](https://paperswithcode.com),\n", "specifically the section for\n", "[computer vision](https://paperswithcode.com/methods/area/computer-vision).\n", "You can also take a more bottom-up approach by checking\n", "the leaderboards of the latest\n", "[Kaggle competitions on computer vision](https://www.kaggle.com/competitions?searchQuery=computer+vision).\n", "\n", "We'll briefly touch here on some of the main design considerations\n", "with classic CNN architectures." ] }, { "cell_type": "markdown", "metadata": { "id": "nd0OeyouDNlS" }, "source": [ "### Shapes and padding" ] }, { "cell_type": "markdown", "metadata": { "id": "5w3p8QP6AnGQ" }, "source": [ "In the `.forward` pass of the `CNN`,\n", "we've included comments that indicate the expected shapes\n", "of tensors after each line that changes the shape.\n", "\n", "Tracking and correctly handling shapes is one of the bugbears\n", "of CNNs, especially architectures,\n", "like LeNet/AlexNet, that include MLP components\n", "that can only operate on fixed-shape tensors." 
] }, { "cell_type": "markdown", "metadata": { "id": "vgbM30jstTFK" }, "source": [ "[Shape arithmetic gets pretty hairy pretty fast](https://arxiv.org/abs/1603.07285)\n", "if you're supporting the wide variety of convolutions.\n", "\n", "The easiest way to avoid shape bugs is to keep things simple:\n", "choose your convolution parameters,\n", "like `padding` and `stride`,\n", "to keep the shape the same before and after\n", "the convolution.\n", "\n", "That's what we do, by choosing `padding=1`\n", "for `kernel_size=3` and `stride=1`.\n", "With unit strides and odd-numbered kernel size,\n", "the padding that keeps\n", "the input the same size is `kernel_size // 2`.\n", "\n", "As shapes change, so does the amount of GPU memory taken up by the tensors.\n", "Keeping sizes fixed within a block removes one axis of variation\n", "in the demands on an important resource.\n", "\n", "After applying our pooling layer,\n", "we can just increase the number of kernels by the right factor\n", "to keep total tensor size,\n", "and thus memory footprint, constant." 
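] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The shape bookkeeping described above can be sketched in plain Python.\n",
"This is a rough illustration rather than code from `text_recognizer`:\n",
"the helper names and the 28x28, 32-channel input are assumptions chosen for the example,\n",
"and the formula ignores dilation.\n",
"\n",
"```python\n",
"def conv_output_size(size, kernel_size, stride=1, padding=0):\n",
"    # standard output-size formula for convolution and pooling windows (no dilation)\n",
"    return (size + 2 * padding - kernel_size) // stride + 1\n",
"\n",
"\n",
"def same_padding(kernel_size):\n",
"    # preserves the input size for stride 1 and odd kernel sizes\n",
"    return kernel_size // 2\n",
"\n",
"\n",
"height = width = 28  # hypothetical input size\n",
"channels = 32  # hypothetical number of channels\n",
"\n",
"# with unit stride and odd kernels, padding of kernel_size // 2 keeps the size fixed\n",
"for kernel_size in (3, 5, 7):\n",
"    padding = same_padding(kernel_size)\n",
"    assert conv_output_size(height, kernel_size, stride=1, padding=padding) == height\n",
"\n",
"# a 2x2 pooling window with stride 2 halves each spatial dimension\n",
"pooled = conv_output_size(height, kernel_size=2, stride=2)\n",
"\n",
"# number of channels needed after pooling to keep the element count unchanged\n",
"elements_before = channels * height * width\n",
"channels_needed = elements_before // (pooled * pooled)\n",
"print(pooled, channels_needed)\n",
"```\n",
"\n",
"With 2x2 pooling the spatial area shrinks by a factor of four,\n",
"so in this example quadrupling the channels keeps the number of elements fixed."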
] }, { "cell_type": "markdown", "metadata": { "id": "2BCkTZGSDSBG" }, "source": [ "### Parameters, computation, and bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "pZbgm7wztTFK" }, "source": [ "If we review the `num`ber of `el`ements in each of the layers,\n", "we see that one layer has far more entries than all the others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nfjPVwztTFK" }, "outputs": [], "source": [ "[p.numel() for p in cnn.parameters()] # conv weight + bias, conv weight + bias, fc weight + bias, fc weight + bias" ] }, { "cell_type": "markdown", "metadata": { "id": "DzIoCz1FtTFK" }, "source": [ "The biggest layer is typically\n", "the one in between the convolutional component\n", "and the MLP component:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYrlUprltTFK" }, "outputs": [], "source": [ "biggest_layer = [p for p in cnn.parameters() if p.numel() == max(p.numel() for p in cnn.parameters())][0]\n", "biggest_layer.shape, cnn.fc_input_dim" ] }, { "cell_type": "markdown", "metadata": { "id": "HSHdvEGptTFL" }, "source": [ "This layer dominates the cost of storing the network on disk.\n", "That makes it a common target for\n", "regularization techniques like DropOut\n", "(as in our architecture)\n", "and performance optimizations like\n", "[pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html).\n", "\n", "Heuristically, we often associate more parameters with more computation.\n", "But just because that layer has the most parameters\n", "does not mean that most of the compute time is spent in that layer.\n", "\n", "Convolutions reuse the same parameters over and over,\n", "so the total number of FLOPs done by the layer can be higher\n", "than that done by layers with more parameters --\n", "much higher." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YLisj1SptTFL" }, "outputs": [], "source": [ "# for the Linear layers, number of multiplications per input == nparams\n", "cnn.fc1.weight.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Yo2oINHRtTFL" }, "outputs": [], "source": [ "# for the Conv2d layers, it's more complicated\n", "\n", "def approx_conv_multiplications(kernel_shape, input_size=(32, 28, 28)): # this is a rough and dirty approximation\n", " num_kernels, input_channels, kernel_height, kernel_width = kernel_shape\n", " input_height, input_width = input_size[1], input_size[2]\n", "\n", " multiplications_per_kernel_application = input_channels * kernel_height * kernel_width\n", " num_applications = ((input_height - kernel_height + 1) * (input_width - kernel_width + 1))\n", " multiplications_per_kernel = num_applications * multiplications_per_kernel_application\n", "\n", " return multiplications_per_kernel * num_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LwCbZU9PtTFL" }, "outputs": [], "source": [ "approx_conv_multiplications(cnn.conv2.conv.weight.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sdco4m9UtTFL" }, "outputs": [], "source": [ "# ratio of multiplications in the convolution to multiplications in the fully-connected layer is large!\n", "approx_conv_multiplications(cnn.conv2.conv.weight.shape) // cnn.fc1.weight.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "joVoBEtqtTFL" }, "source": [ "Depending on your compute hardware and the problem characteristics,\n", "either the MLP component or the convolutional component\n", "could become the critical bottleneck.\n", "\n", "When you're memory constrained, like when transferring a model \"over the wire\" to a browser,\n", "the MLP component is likely to be the bottleneck,\n", "whereas when you are compute-constrained, like when running a model on a low-power edge 
device\n", "or in an application with strict low-latency requirements,\n", "the convolutional component is likely to be the bottleneck.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "pGSyp67dtTFM" }, "source": [ "## Training a `CNN` on `EMNIST` with the Lightning `Trainer` and `run_experiment`" ] }, { "cell_type": "markdown", "metadata": { "id": "AYTJs7snQfX0" }, "source": [ "We have a model and we have data,\n", "so we could just go ahead and start training in raw PyTorch,\n", "[as we did in Lab 01](https://fsdl.me/lab01-colab).\n", "\n", "But as we saw in that lab,\n", "there are good reasons to use a framework\n", "to organize training and provide fixed interfaces and abstractions.\n", "So we're going to use PyTorch Lightning, which is\n", "[covered in detail in Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": { "id": "hZYaJ4bdMcWc" }, "source": [ "We provide a simple script that implements a command line interface\n", "to training with PyTorch Lightning\n", "using the models and datasets in this repository:\n", "`training/run_experiment.py`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52kIYhPBPLNZ" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "rkM_HpILSyC9" }, "source": [ "The `pl.Trainer` arguments come first\n", "and there\n", "[are a lot of them](https://pytorch-lightning.readthedocs.io/en/1.6.3/common/trainer.html),\n", "so if we want to see what's configurable for\n", "our `Model` or our `LitModel`,\n", "we want the last few dozen lines of the help message:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G0dBhgogO8_A" }, "outputs": [], "source": [ "!python training/run_experiment.py --help --model_class CNN --data_class EMNIST | tail -n 25" ] }, { "cell_type": "markdown", "metadata": { "id": "NCBQekrPRt90" }, "source": [ "The `run_experiment.py` file is also importable as a module,\n", "so that you can inspect its contents\n", "and play with its component functions in a notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CPumvYatPaiS" }, "outputs": [], "source": [ "import training.run_experiment\n", "\n", "\n", "print(training.run_experiment.main.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "YiZ3RwW2UzJm" }, "source": [ "Let's run training!\n", "\n", "Execute the cell below to launch a training job for a CNN on EMNIST with default arguments.\n", "\n", "This will take several minutes on commodity hardware,\n", "so feel free to keep reading while it runs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5RSJM5I2TSeG", "scrolled": true }, "outputs": [], "source": [ "gpus = int(torch.cuda.is_available()) # use GPUs if they're available\n", "\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus}" ] }, { "cell_type": "markdown", "metadata": { "id": "_ayQ4ByJOnnP" }, "source": [ "The first thing you'll see are a few logger messages from Lightning,\n", "then some info about the hardware you have available and are using." ] }, { "cell_type": "markdown", "metadata": { "id": "VcMrZcecO1EF" }, "source": [ "Then you'll see a summary of your model,\n", "including module names, parameter counts,\n", "and information about model disk size.\n", "\n", "`torchmetrics` show up here as well,\n", "since they are also `nn.Module`s.\n", "See [Lab 02a](https://fsdl.me/lab02a-colab)\n", "for details.\n", "We're tracking accuracy on training, validation, and test sets." ] }, { "cell_type": "markdown", "metadata": { "id": "twGp9iWOUSfc" }, "source": [ "You may also see a quick message in the terminal\n", "referencing a \"validation sanity check\".\n", "PyTorch Lightning runs a few batches of validation data\n", "through the model before the first training epoch.\n", "This helps prevent training runs from crashing\n", "at the end of the first epoch,\n", "which is otherwise the first time validation loops are triggered\n", "and is sometimes hours into training,\n", "by crashing them quickly at the start.\n", "\n", "If you want to turn off the check,\n", "use `--num_sanity_val_steps=0`." ] }, { "cell_type": "markdown", "metadata": { "id": "jnKN3_MiRpE4" }, "source": [ "Then, you'll see a bar indicating\n", "progress through the training epoch,\n", "alongside metrics like throughput and loss.\n", "\n", "When the first (and only) epoch ends,\n", "the model is run on the validation set\n", "and aggregate loss and accuracy are reported to the console." 
] }, { "cell_type": "markdown", "metadata": { "id": "R2eMZz_HR8vV" }, "source": [ "At the end of training,\n", "we call `Trainer.test`\n", "to check performance on the test set.\n", "\n", "We typically see test accuracy around 75-80%." ] }, { "cell_type": "markdown", "metadata": { "id": "ybpLiKBKSDXI" }, "source": [ "During training, PyTorch Lightning saves _checkpoints_\n", "(file extension `.ckpt`)\n", "that can be used to restart training.\n", "\n", "The final line output by `run_experiment`\n", "indicates where the model with the best performance\n", "on the validation set has been saved.\n", "\n", "The checkpointing behavior is configured using a\n", "[`ModelCheckpoint` callback](https://pytorch-lightning.readthedocs.io/en/1.6.3/api/pytorch_lightning.callbacks.ModelCheckpoint.html).\n", "The `run_experiment` script picks sensible defaults.\n", "\n", "These checkpoints contain the model weights.\n", "We can use them to load the model in the notebook and play around with it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3Rqh9ZQsY8g4" }, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest checkpoint's filename\n", "# by hand, you can just copy and paste it\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs\" # find avoids issues with \n in filenames\n", "filter_to_ckpts = \"grep \.ckpt$\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "latest_ckpt" ] }, { "cell_type": "markdown", "metadata": { "id": "7QW_CxR3coV6" }, "source": [ "To rebuild the model,\n", "we need to consider some implementation details of the `run_experiment` script.\n", "\n", "We use the parsed command line arguments, the `args`, to build the data and model,\n", "then use all three to build the `LightningModule`.\n", "\n", "Any `LightningModule` can be reinstantiated from a checkpoint\n", "using the `load_from_checkpoint` method,\n", "but we'll need to recreate and pass the `args`\n", "in order to reload the model.\n", "(We'll see how this can be automated later.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oVWEHcgvaSqZ" }, "outputs": [], "source": [ "import training.util\n", "from argparse import Namespace\n", "\n", "\n", "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"CNN\",\n", " \"data_class\": \"EMNIST\"})\n", "\n", "\n", "_, cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "reloaded_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=cnn)" ] }, { "cell_type": "markdown", "metadata": { "id": "MynyI_eUcixa" }, "source": [ "With the model reloaded, we can run it on some sample data\n", "and see how it's doing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L0HCxgVwcRAA" }, "outputs": [], "source": [ "idx = random.randint(0, len(xs) - 1)\n", "outs = reloaded_model(xs[idx:idx+1])\n", "\n", "print(\"output:\", emnist.mapping[torch.argmax(outs)])\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "G6NtaHuVdfqt" }, 
"source": [ "I generally see subjectively good performance --\n", "without seeing the labels, I tend to agree with the model's output\n", "more often than the accuracy would suggest,\n", "since some classes, like c and C or o, O, and 0,\n", "are essentially indistinguishable." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZzcDcxpVkki" }, "source": [ "We can continue a promising training run from the checkpoint.\n", "Run the cell below to train the model just trained above\n", "for another epoch.\n", "Note that the training loss starts out close to where it ended\n", "in the previous run.\n", "\n", "Paired with cloud storage of checkpoints,\n", "this makes it possible to use\n", "[a cheaper type of cloud instance](https://cloud.google.com/blog/products/ai-machine-learning/reduce-the-costs-of-ml-workflows-with-preemptible-vms-and-gpus)\n", "that can be pre-empted by someone willing to pay more,\n", "which terminates your job.\n", "It's also helpful when using Google Colab for more serious projects --\n", "your training runs are no longer bound by the maximum uptime of a Colab notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "skqdikNtVnaf" }, "outputs": [], "source": [ "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "\n", "\n", "# and we can change the training hyperparameters, like batch size\n", "%run training/run_experiment.py --model_class CNN --data_class EMNIST --gpus {gpus} \\\n", " --batch_size 64 --load_checkpoint {latest_ckpt}" ] }, { "cell_type": "markdown", "metadata": { "id": "HBdNt6Z2tTFM" }, "source": [ "# Creating lines of text from handwritten characters: `EMNISTLines`" ] }, { "cell_type": "markdown", "metadata": { "id": "FevtQpeDtTFM" }, "source": [ "We've got a training pipeline for our model and our data,\n", "and we can use that to make the loss go down\n", "and get better at the task.\n", "But the problem we're solving is not obviously useful:\n", "the model is just learning how to handle\n", "centered, high-contrast, isolated characters.\n", "\n", "To make this work in a text recognition application,\n", "we would need a component to first pull out characters like that from images.\n", "That task is probably harder than the one we're currently learning.\n", "Plus, splitting into two separate components is against the ethos of deep learning,\n", "which operates \"end-to-end\".\n", "\n", "Let's kick the realism up one notch by building lines of text out of our characters:\n", "_synthesizing_ data for our model." ] }, { "cell_type": "markdown", "metadata": { "id": "dH7i4JhWe7ch" }, "source": [ "Synthetic data is generally useful for augmenting limited real data.\n", "By construction we know the labels, since we created the data.\n", "Often, we can track covariates,\n", "like lighting features or subclass membership,\n", "that aren't always available in our labels." 
] }, { "cell_type": "markdown", "metadata": { "id": "TrQ_44TIe39m" }, "source": [ "To build fake handwriting,\n", "we'll combine two things:\n", "real handwritten letters and real text.\n", "\n", "We generate our fake text by drawing from the\n", "[Brown corpus](https://en.wikipedia.org/wiki/Brown_Corpus)\n", "provided by the [`n`atural `l`anguage `t`ool`k`it](https://www.nltk.org/) library.\n", "\n", "First, we download that corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gtSg7Y8Ydxpa" }, "outputs": [], "source": [ "from text_recognizer.data.sentence_generator import SentenceGenerator\n", "\n", "sentence_generator = SentenceGenerator()\n", "\n", "SentenceGenerator.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "yal5eHk-aB4i" }, "source": [ "We can generate short snippets of text from the corpus with the `SentenceGenerator`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eRg_C1TYzwKX" }, "outputs": [], "source": [ "print(*[sentence_generator.generate(max_length=16) for _ in range(4)], sep=\"\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JGsBuMICaXnM" }, "source": [ "We use another `DataModule` to pick out the needed handwritten characters from `EMNIST`\n", "and glue them together into images containing the generated text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YtsGfSu6dpZ9" }, "outputs": [], "source": [ "emnist_lines = text_recognizer.data.EMNISTLines() # configure\n", "emnist_lines.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "dik_SyEdb0st" }, "source": [ "This can take several minutes when first run,\n", "but afterwards data is persisted to disk." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SofIYHOUtTFM" }, "outputs": [], "source": [ "emnist_lines.prepare_data() # download, save to disk\n", "emnist_lines.setup() # create torch.utils.data.Datasets, do train/val split\n", "emnist_lines" ] }, { "cell_type": "markdown", "metadata": { "id": "axESuV1SeoM6" }, "source": [ "Again, we're using the `LightningDataModule` interface\n", "to organize our data prep,\n", "so we can now fetch a batch and take a look at some data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1J7f2I9ggBi-" }, "outputs": [], "source": [ "line_xs, line_ys = next(iter(emnist_lines.val_dataloader()))\n", "line_xs.shape, line_ys.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B0yHgbW2gHgP" }, "outputs": [], "source": [ "def read_line_labels(labels):\n", " return [emnist_lines.mapping[label] for label in labels]\n", "\n", "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "print(\"-\".join(read_line_labels(line_ys[idx])))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xirEmNPNtTFM" }, "source": [ "The result looks\n", "[kind of like a ransom note](https://tvtropes.org/pmwiki/pmwiki.php/Main/CutAndPasteNote)\n", "and is not yet anywhere near realistic, even for single lines --\n", "letters don't overlap, the exact same handwritten letter is repeated\n", "if the character appears more than once in the snippet --\n", "but it's a start." ] }, { "cell_type": "markdown", "metadata": { "id": "eRWbSzkotTFM" }, "source": [ "# Applying CNNs to handwritten text: `LineCNNSimple`" ] }, { "cell_type": "markdown", "metadata": { "id": "pzwYBv82tTFM" }, "source": [ "The `LineCNNSimple` class builds on the `CNN` class and can be applied to this dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZqeImjd2lF7p" }, "outputs": [], "source": [ "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "line_cnn" ] }, { "cell_type": "markdown", "metadata": { "id": "Hi6g0acoxJO4" }, "source": [ "The `nn.Module`s look much the same,\n", "but the way they are used is different,\n", "which we can see by examining the `.forward` method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Qg3UJhibxHfC" }, "outputs": [], "source": [ "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "LAW7EWVlxMhd" }, "source": [ "The `CNN`, which operates on square images,\n", "is applied to our wide image repeatedly,\n", "slid over by the `W`indow `S`ize each time.\n", "We effectively convolve the network with the input image.\n", "\n", "Like our synthetic data, it is crude\n", "but it's enough to get started." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FU4J13yLisiC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = line_cnn(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "OxHI4Gzndbxg" }, "source": [ "> You may notice that this randomly-initialized\n", "network tends to predict some characters far more often than others,\n", "rather than predicting all characters with equal likelihood.\n", "This is a commonly-observed phenomenon in deep networks.\n", "It is connected to issues with\n", "[model calibration](https://arxiv.org/abs/1706.04599)\n", "and Bayesian uses of DNNs\n", "(see e.g. Figure 7 of\n", "[Wenzel et al. 2020](https://arxiv.org/abs/2002.02405))." 
] }, { "cell_type": "markdown", "metadata": { "id": "NSonI9KcfJrB" }, "source": [ "Let's launch a training run with the default parameters.\n", "\n", "This cell should run in just a few minutes on typical hardware." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsbJdeRiwSVA" }, "outputs": [], "source": [ "%run training/run_experiment.py --model_class LineCNNSimple --data_class EMNISTLines \\\n", " --batch_size 32 --gpus {gpus} --max_epochs 2" ] }, { "cell_type": "markdown", "metadata": { "id": "y9e5nTplfoXG" }, "source": [ "You should see a test accuracy in the 65-70% range.\n", "\n", "That seems pretty good,\n", "especially for a simple model trained in a minute.\n", "\n", "Let's reload the model and run it on some examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0NuXazAvw9NA" }, "outputs": [], "source": [ "# if you change around model/data args in the command above, add them here\n", "# tip: define the arguments as variables, like we've done for gpus\n", "# and then add those variables to this dict so you don't need to\n", "# remember to update/copy+paste\n", "\n", "args = Namespace(**{\n", " \"model_class\": \"LineCNNSimple\",\n", " \"data_class\": \"EMNISTLines\"})\n", "\n", "\n", "_, line_cnn = training.util.setup_data_and_model_from_args(args)\n", "\n", "latest_ckpt, = ! 
{list_all_log_files} | {filter_to_ckpts} | {sort_version_descending} | {take_first}\n", "print(latest_ckpt)\n", "\n", "reloaded_lines_model = text_recognizer.lit_models.BaseLitModel.load_from_checkpoint(\n", " latest_ckpt, args=args, model=line_cnn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "J8ziVROkxkGC" }, "outputs": [], "source": [ "idx = random.randint(0, len(line_xs) - 1)\n", "\n", "outs, = reloaded_lines_model(line_xs[idx:idx+1])\n", "preds = torch.argmax(outs, 0)\n", "\n", "print(\"-\".join(read_line_labels(preds)))\n", "wandb.Image(line_xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "N9bQCHtYgA0S" }, "source": [ "In general,\n", "we see predictions that have very low subjective quality:\n", "it seems like most of the letters are wrong\n", "and the model often prefers to predict the most common letters\n", "in the dataset, like `e`.\n", "\n", "Notice, however, that many of the\n", "characters in a given line are padding characters, `<P>`.\n", "\n", "A model that always predicts `<P>` can achieve around 50% accuracy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EE-T7zgDgo7-" }, "outputs": [], "source": [ "padding_token = emnist_lines.emnist.inverse_mapping[\"<P>\"]\n", "torch.sum(line_ys == padding_token) / line_ys.numel()" ] }, { "cell_type": "markdown", "metadata": { "id": "rGHWmOyVh5rV" }, "source": [ "There are ways to adjust your classification metrics to\n", "[handle this particular issue](https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall).\n", "In general it's good to find a metric\n", "that has baseline performance at 0 and perfect performance at 1,\n", "so that numbers are clearly interpretable.\n", "\n", "But it's an important reminder to actually look\n", "at your model's behavior from time to time.\n", "Metrics are single numbers,\n", "so they by necessity throw away a ton of information\n", "about your model's behavior,\n", "some of which is deeply relevant." ] }, { "cell_type": "markdown", "metadata": { "id": "6p--KWZ9YJWQ" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "srQnoOK8YLDv" }, "source": [ "### 🌟 Research a `pl.Trainer` argument and try it out." ] }, { "cell_type": "markdown", "metadata": { "id": "7j652MtkYR8n" }, "source": [ "The Lightning `Trainer` class is highly configurable\n", "and has accumulated a number of features as Lightning has matured.\n", "\n", "Check out the documentation for this class\n", "and pick an argument to try out with `training/run_experiment.py`.\n", "Look for edge cases in its behavior,\n", "especially when combined with other arguments." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8UWNicq_jS7k" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "pl_version = pl.__version__\n", "\n", "print(\"pl.Trainer guide URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/common/trainer.html\")\n", "print(\"pl.Trainer reference docs URL:\", f\"https://pytorch-lightning.readthedocs.io/en/{pl_version}/api/pytorch_lightning.trainer.trainer.Trainer.html\")\n", "\n", "pl.Trainer??" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "14AOfjqqYOoT" }, "outputs": [], "source": [ "%run training/run_experiment.py --help" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "lab02b_cnn.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab08/notebooks/lab03_transformers.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 03: Transformers and Paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- The fundamental reasons why the Transformer is such\n", "a powerful and popular architecture\n", "- Core intuitions for the behavior of Transformer architectures\n", "- How to use a convolutional encoder and a Transformer decoder to recognize\n", "entire paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 3\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "XZN4bGgsgWc_" }, "source": [ "# Why Transformers?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our goal in building a text recognizer is to take a two-dimensional image\n", "and convert it into a one-dimensional sequence of characters\n", "from some alphabet." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional neural networks,\n", "discussed in [Lab 02b](https://fsdl.me/lab02b-colab),\n", "are great at encoding images,\n", "taking them from their raw pixel values\n", "to a more semantically meaningful numerical representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But how do we go from that to a sequence of letters?\n", "And what's especially tricky:\n", "the number of letters in an image is separable from its size.\n", "A screenshot of this document has a much higher density of letters\n", "than a close-up photograph of a piece of paper.\n", "How do we get a _variable-length_ sequence of letters,\n", "where the length need have nothing to do with the size of the input tensor?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_Transformers_ are an encoder-decoder architecture that excels at sequence modeling --\n", "they were\n", "[originally introduced](https://arxiv.org/abs/1706.03762)\n", "for transforming one sequence into another,\n", "as in machine translation.\n", "This makes them a natural fit for processing language.\n", "\n", "But they have also found success in other domains --\n", "at the time of this writing, large transformers\n", "dominate the\n", "[ImageNet classification benchmark](https://paperswithcode.com/sota/image-classification-on-imagenet)\n", "that has become a de facto standard for comparing models\n", "and are finding\n", "[application in reinforcement learning](https://arxiv.org/abs/2106.01345)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So we will use a Transformer as a key component of our final architecture:\n", "we will encode our input images with a CNN\n", "and then read them out into a text sequence with a Transformer.\n", "\n", "Before trying out this new model,\n", "let's first get an understanding of why the Transformer architecture\n", "has become so popular by walking through its history\n", "and then get some intuition for how it works\n", "by looking at some\n", "[recent work](https://transformer-circuits.pub/)\n", "on explaining the behavior of both toy models and state-of-the-art language models." 
] }, { "cell_type": "markdown", "metadata": { "id": "kmKqjbvd-Mj3" }, "source": [ "## Why not convolutions?" ] }, { "cell_type": "markdown", "metadata": { "id": "SRqkUMdM-OxU" }, "source": [ "In the ancient beforetimes (i.e. 2016),\n", "the best models for natural language processing were all\n", "_recurrent_ neural networks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convolutional networks were also occasionally used,\n", "but they suffered from a serious issue:\n", "their architectural biases don't fit text.\n", "\n", "First, _translation equivariance_ no longer holds.\n", "The beginning of a piece of text is often quite different from the middle,\n", "so the absolute position matters.\n", "\n", "Second, _locality_ is not as important in language.\n", "The name of a character that hasn't appeared in thousands of pages\n", "can become salient when someone asks, \"Whatever happened to\n", "[Radagast the Brown](https://tvtropes.org/pmwiki/pmwiki.php/ChuckCunninghamSyndrome/Literature)?\"\n", "\n", "Consider interpreting a piece of text like the Python code below:\n", "```python\n", "def do(arg1, arg2, arg3):\n", " a = arg1 + arg2\n", " b = arg3[:3]\n", " c = a * b\n", " return c\n", "\n", "print(do(1, 1, \"ayy lmao\"))\n", "```\n", "\n", "After a `(` we expect a `)`,\n", "but possibly very long afterwards,\n", "[e.g. in the definition of `pl.Trainer.__init__`](https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.__init__),\n", "and similarly we expect a `]` at some point after a `[`.\n", "\n", "For translation variance, consider\n", "that we interpret `*` not by\n", "comparing it to its neighbors\n", "but by looking at `a` and `b`.\n", "We mix knowledge learned through experience\n", "with new facts learned while reading --\n", "also known as _in-context learning_.\n", "\n", "In a longer text,\n", "[e.g. 
the one you are reading now](./lab03_transformers.ipynb),\n", "the translation variance of text is clearer.\n", "Every lab notebook begins with the same header,\n", "setting up the environment,\n", "but that header never appears elsewhere in the notebook.\n", "Later positions need to be processed in terms of the previous entries.\n", "\n", "Unlike an image, we cannot simply rotate or translate our \"camera\"\n", "and get a new valid text.\n", "[Rare is the book](https://en.wikipedia.org/wiki/Dictionary_of_the_Khazars)\n", "that can be read without regard to position." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The field of formal language theory,\n", "which has deep mutual influence with computer science,\n", "gives one way of explaining the issues with convolutional networks:\n", "they can only understand languages with _finite contexts_,\n", "where all the information can be found within a finite window." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The immediate solution, drawing from the connections to computer science, is\n", "[recursion](https://www.google.com/search?q=recursion).\n", "A network whose output on the final entry of the sequence is a recursive function\n", "of all the previous entries can build up knowledge\n", "as it reads the sequence and treat early entries quite differently than it does late ones." 
] }, { "cell_type": "markdown", "metadata": { "id": "aa6cbTlImkEh" }, "source": [ "In pseudo-code, such a _recurrent neural network_ module might look like:" ] }, { "cell_type": "markdown", "metadata": { "id": "lKtBoPnglPrW" }, "source": [ "```python\n", "def recurrent_module(xs: torch.Tensor[\"S\", \"input_dims\"]) -> torch.Tensor[\"feature_dims\"]:\n", " next_inputs = input_module(xs[-1])\n", " next_hiddens = feature_module(recurrent_module(xs[:-1])) # recursive call\n", " return output_module(next_inputs, next_hiddens)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "IbJPSMnEm516" }, "source": [ "If you've had formal computer science training,\n", "then you may be familiar with the power of recursion,\n", "e.g. the\n", "[Y-combinator](https://en.wikipedia.org/wiki/Fixed-point_combinator#Y_combinator)\n", "that gave its name to the now much better-known\n", "[startup incubator](https://www.ycombinator.com/).\n", "\n", "The particular form of recursion used by\n", "recurrent neural networks implements a\n", "[reduce-like operation](https://colah.github.io/posts/2015-09-NN-Types-FP/).\n", "\n", "> If you know a lot of computer science,\n", "you might be concerned by this connection.\n", "What about other\n", "[recursion schemes](https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html)?\n", "Where are the neural network architectures for differentiable\n", "[zygohistomorphic prepromorphisms](https://wiki.haskell.org/Zygohistomorphic_prepromorphisms)?\n", "Check out Graph Neural Networks,\n", "[which implement dynamic programming](https://arxiv.org/abs/2203.15544)." 
] }, { "cell_type": "markdown", "metadata": { "id": "63mMTbEBpVuE" }, "source": [ "Recurrent networks are able to achieve\n", "[decent results in language modeling and machine translation](https://paperswithcode.com/paper/regularizing-and-optimizing-lstm-language).\n", "\n", "There are many popular recurrent architectures,\n", "from the beefy and classic\n", "[LSTM](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) \n", "to the svelte and modern [GRU](https://arxiv.org/abs/1412.3555)\n", "([no relation](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/gru.jpeg)),\n", "all of which have roughly similar capabilities but\n", "[some of which are easier to train](https://arxiv.org/abs/1611.09913)." ] }, { "cell_type": "markdown", "metadata": { "id": "PwQHVTIslOku" }, "source": [ "In the same sense that MLPs can model \"any\" feedforward function,\n", "in principle even basic RNNs\n", "[can model \"any\" dynamical system](https://www.sciencedirect.com/science/article/abs/pii/S089360800580125X).\n", "\n", "In particular they can model any\n", "[Turing machine](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis),\n", "which is a formal way of saying that they can in principle\n", "do anything a computer is capable of doing.\n", "\n", "The question is then..." ] }, { "cell_type": "markdown", "metadata": { "id": "3J8EoGN3pu7P" }, "source": [ "## Why aren't we all using RNNs?" 
] }, { "cell_type": "markdown", "metadata": { "id": "TDwNWaevpt_3" }, "source": [ "The guarantees that MLPs can model any function\n", "or that RNNs can model Turing machines\n", "provide decent intuition but are not directly practically useful.\n", "Among other reasons, they don't guarantee learnability --\n", "that starting from random parameters we can find the parameters\n", "that implement a given function.\n", "The\n", "[effective capacity of neural networks is much lower](https://arxiv.org/abs/1901.09021)\n", "than would seem from basic theoretical and empirical analysis.\n", "\n", "One way of understanding capacity to model language is\n", "[the Chomsky hierarchy](https://en.wikipedia.org/wiki/Chomsky_hierarchy).\n", "In this model of formal languages,\n", "Turing machines sit at the top\n", "([practically speaking](https://arxiv.org/abs/math/0209332)).\n", "\n", "With better mathematical models,\n", "RNNs and LSTMs can be shown to be\n", "[much weaker within the Chomsky hierarchy](https://arxiv.org/abs/2102.10094),\n", "with RNNs looking more like\n", "[a regex parser](https://en.wikipedia.org/wiki/Finite-state_machine#Acceptors)\n", "and LSTMs coming in\n", "[just above them](https://en.wikipedia.org/wiki/Counter_automaton).\n", "\n", "More controversially:\n", "the Chomsky hierarchy is great for understanding syntax and grammar,\n", "which makes it great for building parsers\n", "and working with formal languages,\n", "but the goal in _natural_ language processing is to understand _natural_ language.\n", "Most humans' natural language is far from strictly grammatical,\n", "but that doesn't mean it is nonsense.\n", "\n", "And to really \"understand\" language means\n", "to understand its semantic content, which is fuzzy.\n", "The most important thing for handling the fuzzy semantic content\n", "of language is not whether you can recall\n", "[a parenthesis arbitrarily far in the past](https://en.wikipedia.org/wiki/Dyck_language)\n", "but whether you can 
model probabilistic relationships between concepts\n", "in addition to grammar and syntax." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These both leave theoretical room for improvement over current recurrent\n", "language and sequence models.\n", "\n", "But the real cause of the rise of Transformers is that..." ] }, { "cell_type": "markdown", "metadata": { "id": "Dsu1ebvAp-3Z" }, "source": [ "## Transformers are designed to train fast at scale on contemporary hardware." ] }, { "cell_type": "markdown", "metadata": { "id": "c4abU5adsPGs" }, "source": [ "The Transformer architecture has several important features,\n", "discussed below,\n", "but one of the most important reasons why it is successful\n", "is that it can be more easily trained at scale.\n", "\n", "This scalability is the focus of the discussion in the paper\n", "that introduced the architecture,\n", "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", "and\n", "[comes up whenever there's speculation about scaling up recurrent models](https://twitter.com/jekbradbury/status/1550928156504100864).\n", "\n", "The recursion in RNNs is inherently sequential:\n", "the dependence on the outputs from earlier in the sequence\n", "means computations within an example cannot be parallelized.\n", "\n", "So RNNs must batch across examples to scale,\n", "but as sequence length grows this hits memory bandwidth limits.\n", "Serving up large batches quickly with good randomness guarantees\n", "is also hard to optimize,\n", "especially in distributed settings.\n", "\n", "The Transformer architecture,\n", "on the other hand,\n", "can be readily parallelized within a single example sequence,\n", "in addition to parallelization across batches.\n", "This can lead to massive performance gains for a fixed scale,\n", "which means larger, higher capacity models\n", "can be trained on larger datasets." 
] }, { "cell_type": "markdown", "metadata": { "id": "_Mzk2haFC_G1" }, "source": [ "How does the architecture achieve this parallelizability?\n", "\n", "Let's start with the architecture diagram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u59eu4snLQfp" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "base_url = \"https://fsdl-public-assets.s3.us-west-2.amazonaws.com\"\n", "\n", "display.Image(url=base_url + \"/aiayn-figure-1.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ez-XEQ7M0UlR" }, "source": [ "> To head off a bit of confusion\n", " in case you've worked with Transformer architectures before:\n", " the original \"Transformer\" is an encoder/decoder architecture.\n", " Many LLMs, like GPT models, are decoder only,\n", " because this has turned out to scale well,\n", " and in NLP you can always just make the inputs part of the \"outputs\" by prepending --\n", " it's all text anyways.\n", " We, however, will be using them across modalities,\n", " so we need an explicit encoder,\n", " as above. " ] }, { "cell_type": "markdown", "metadata": { "id": "ok4ksBi4vp89" }, "source": [ "First focusing on the encoder (left):\n", "the encoding at a given position is a function of all previous inputs.\n", "But it is not a function of the previous _encodings_:\n", "we produce the encodings \"all at once\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "RPN7C-_OqzHP" }, "source": [ "The decoder (right) does use previous \"outputs\" as its inputs,\n", "but those outputs are not the vectors of layer activations\n", "(aka embeddings)\n", "that are produced by the network.\n", "They are instead the processed outputs,\n", "after a `softmax` and an `argmax`.\n", "\n", "We could obtain these outputs by processing the embeddings,\n", "much like in a recurrent architecture.\n", "In fact, that is one way that Transformers are run.\n", "It's what happens in the `.forward` method\n", "of the model we'll be training for character recognition:\n", "`ResnetTransformer`." ] }, { "cell_type": "markdown", "metadata": { "id": "L5_2WMmtDnJn" }, "source": [ "Let's look at that forward method\n", "and connect it to the diagram." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FR5pk4kEyCGg" }, "outputs": [], "source": [ "from text_recognizer.models import ResnetTransformer\n", "\n", "\n", "ResnetTransformer.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "-J5UFDoPzPbq" }, "source": [ "`.encode` happens first -- that's the left side of the diagram.\n", "\n", "The encoder can in principle be anything\n", "that produces a sequence of fixed-length vectors,\n", "but here it's\n", "[a `ResNet` implementation from `torchvision`](https://pytorch.org/vision/stable/models.html).\n", "\n", "Then we start iterating over the sequence\n", "in the `for` loop.\n", "\n", "Focus on the first few lines of code.\n", "We apply `.decode` (right side of the diagram)\n", "to the outputs so far.\n", "\n", "Once we have a new `output`, we apply `.argmax`\n", "to turn the logits into a concrete prediction of\n", "a particular token.\n", "\n", "This is added as the last output token\n", "and then the loop happens again." 
] }, { "cell_type": "markdown", "metadata": { "id": "LTcy8-rV1dHr" }, "source": [ "Run this way, our model looks very much like a recurrent architecture:\n", "we call the model on its own outputs\n", "to generate the next value.\n", "These types of models are also referred to as\n", "[autoregressive models](https://deepgenerativemodels.github.io/notes/autoregressive/),\n", "because we predict (as we do in _regression_)\n", "the next value based on our own (_auto_) output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But Transformers are designed to be _trained_ more scalably than RNNs,\n", "not necessarily to _run inference_ more scalably,\n", "and it's actually not the case that our model's `.forward` is called during training." ] }, { "cell_type": "markdown", "metadata": { "id": "eCxMSAWmEKBt" }, "source": [ "Let's look at what happens during training\n", "by checking the `training_step`\n", "of the `LightningModule`\n", "we use to train our Transformer models,\n", "the `TransformerLitModel`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0o7q8N7P2w4H" }, "outputs": [], "source": [ "from text_recognizer.lit_models import TransformerLitModel\n", "\n", "TransformerLitModel.training_step??" ] }, { "cell_type": "markdown", "metadata": { "id": "1VgNNOjvzC4y" }, "source": [ "Notice that we call `.teacher_forward` on the inputs, instead of `model.forward`." ] }, { "cell_type": "markdown", "metadata": { "id": "tz-6NGPR4dUr" }, "source": [ "Let's look at `.teacher_forward`,\n", "and in particular its type signature:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ILc2oWET4i2Z" }, "outputs": [], "source": [ "TransformerLitModel.teacher_forward??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function uses both inputs `x` _and_ ground truth targets `y` to produce the `outputs`." 
] }, { "cell_type": "markdown", "metadata": { "id": "lf32lpgrDb__" }, "source": [ "This is known as \"teacher forcing\".\n", "The \"teacher\" signal is \"forcing\"\n", "the model to behave as though\n", "it got the answer right.\n", "\n", "[Teacher forcing was originally developed for RNNs](https://direct.mit.edu/neco/article-abstract/1/2/270/5490/A-Learning-Algorithm-for-Continually-Running-Fully).\n", "It's more effective here\n", "because the right teaching signal\n", "for our network is the target data,\n", "which we have access to during training,\n", "whereas in an RNN the best teaching signal\n", "would be the target embedding vector,\n", "which we do not know.\n", "\n", "During inference, when we don't have access to the ground truth,\n", "we revert to the autoregressive `.forward` method." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This \"trick\" allows Transformer architectures to readily scale\n", "up models to the parameter counts\n", "[required to make full use of internet-scale datasets](https://arxiv.org/abs/2001.08361)." ] }, { "cell_type": "markdown", "metadata": { "id": "BAjqpJm9uUuU" }, "source": [ "## Is there more to Transformers than just a training trick?" ] }, { "cell_type": "markdown", "metadata": { "id": "kWCYXeHv7Qc9" }, "source": [ "[Very](https://arxiv.org/abs/2005.14165),\n", "[very](https://arxiv.org/abs/1909.08053),\n", "[very](https://arxiv.org/abs/2205.01068)\n", "large Transformer models have powered the most recent wave of exciting results in ML, like\n", "[photorealistic high-definition image generation](https://cdn.openai.com/papers/dall-e-2.pdf).\n", "\n", "They are also the first machine learning models to have come anywhere close to\n", "deserving the term _artificial intelligence_ --\n", "a slippery concept, but \"how many Turing-type tests do you pass?\" is a good barometer."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is surprising because the models and their training procedure are\n", "(relatively speaking)\n", "pretty _simple_,\n", "even if it doesn't feel that way on first pass." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic Transformer architecture is just a bunch of\n", "dense matrix multiplications and non-linearities --\n", "it's perhaps simpler than a convolutional architecture." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And advances since the introduction of Transformers in 2017\n", "have not in the main been made by\n", "creating more sophisticated model architectures\n", "but by increasing the scale of the base architecture,\n", "or if anything making it simpler, as in\n", "[GPT-type models](https://arxiv.org/abs/2005.14165),\n", "which drop the encoder." ] }, { "cell_type": "markdown", "metadata": { "id": "V1HQS9ey8GMc" }, "source": [ "These models are also trained on very simple tasks:\n", "most LLMs are just trying to predict the next element in the sequence,\n", "given the previous elements --\n", "a task simple enough that Claude Shannon,\n", "father of information theory, was\n", "[able to work on it in the 1950s](https://www.princeton.edu/~wbialek/rome/refs/shannon_51.pdf).\n", "\n", "These tasks are chosen because it is easy to obtain extremely large-scale datasets,\n", "e.g. by scraping the web." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They are also trained in a simple fashion:\n", "first-order stochastic optimizers, like SGD or an\n", "[ADAM variant](https://optimization.cbe.cornell.edu/index.php?title=Adam),\n", "intended for the most basic of optimization problems,\n", "that scale more readily than the second-order optimizers\n", "that dominate other areas of optimization." 
] }, { "cell_type": "markdown", "metadata": { "id": "Kz9HPDoy7OAl" }, "source": [ "This is\n", "[the bitter lesson](http://www.incompleteideas.net/IncIdeas/BitterLesson.html)\n", "of work in ML:\n", "simple, even seemingly wasteful,\n", "architectures that scale well and are robust\n", "to implementation details\n", "eventually outstrip more clever but\n", "also more finicky approaches that are harder to scale.\n", "This lesson has led some to declare that\n", "[scale is all you need](https://fsdl-public-assets.s3.us-west-2.amazonaws.com/siayn.jpg)\n", "in machine learning, and perhaps even in artificial intelligence." ] }, { "cell_type": "markdown", "metadata": { "id": "SdN9o2Y771YZ" }, "source": [ "> That is not to say that because the algorithms are relatively simple,\n", " training a model at this scale is _easy_ --\n", " [datasets require cleaning](https://openreview.net/forum?id=UoEw6KigkUn),\n", " [model architectures require tuning and hyperparameter selection](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2),\n", " [distributed systems require care and feeding](https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/chronicles/OPT175B_Logbook.pdf).\n", " But choosing the simplest algorithm at every step makes solving the scaling problem feasible." ] }, { "cell_type": "markdown", "metadata": { "id": "baVGf6gKFOvs" }, "source": [ "The importance of scale is the key lesson from the Transformer architecture,\n", "far more than any theoretical considerations\n", "or any of the implementation details.\n", "\n", "That said, these large Transformer models are capable of\n", "impressive behaviors and understanding how they achieve them\n", "is of intellectual interest.\n", "Furthermore, like any architecture,\n", "there are common failure modes,\n", "of the model and of the modelers who use them,\n", "that need to be taken into account." 
] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "Below, we'll cover two key intuitions about Transformers:\n", "Transformers are _residual_, like ResNets,\n", "and they compose _low rank_ sequence transformations.\n", "Together, this means they act somewhat like a computer,\n", "reading from and writing to a \"tape\" or memory\n", "with a sequence of simple instructions." ] }, { "cell_type": "markdown", "metadata": { "id": "1t2Cfq9Fq67Q" }, "source": [ "We'll also cover a surprising implementation detail:\n", "despite being commonly used for sequence modeling,\n", "by default the architecture is _position insensitive_." ] }, { "cell_type": "markdown", "metadata": { "id": "uni0VTCr9lev" }, "source": [ "### Intuition #1: Transformers are highly residual." ] }, { "cell_type": "markdown", "metadata": { "id": "0MoBt-JLJz-d" }, "source": [ "> The discussion of these intuitions summarizes the discussion in\n", "[A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html)\n", "from\n", "[Anthropic](https://www.anthropic.com/),\n", "an AI safety and research company.\n", "The figures below are from that blog post.\n", "It is the spiritual successor to the\n", "[Circuits Thread](https://distill.pub/2020/circuits/)\n", "covered in\n", "[Lab 02b](https://fsdl.me/lab02b-colab).\n", "If you want to truly understand Transformers,\n", "we highly recommend you check it out,\n", "including the\n", "[associated exercises](https://transformer-circuits.pub/2021/exercises/index.html)."
] }, { "cell_type": "markdown", "metadata": { "id": "UUbNVvM5Ferm" }, "source": [ "It's easy to see that ResNets are residual --\n", "it's in the name, after all.\n", "\n", "But Transformers are,\n", "in some sense,\n", "even more closely tied to residual computation\n", "than are ResNets:\n", "ResNets and related architectures include downsampling,\n", "so there is not a direct path from inputs to outputs.\n", "\n", "In Transformers, the exact same shape is maintained\n", "from the moment tokens are embedded,\n", "through dozens or hundreds of intermediate layers,\n", "and until they are \"unembedded\" into class logits.\n", "The Transformer Circuits authors refer to this pathway as the \"residual stream\".\n", "\n", "The residual stream is easy to see with a change of perspective.\n", "Instead of the usual architecture diagram above,\n", "which emphasizes the layers acting on the tensors,\n", "consider this alternative view,\n", "which emphasizes the tensors as they pass through the layers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HRMlVguKKW6y" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-residual-view.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9K3N7ilVkB3" }, "source": [ "For definitions of variables and terms, see the\n", "[notation reference here](https://transformer-circuits.pub/2021/framework/index.html#notation)." ] }, { "cell_type": "markdown", "metadata": { "id": "arvciE-kKd_L" }, "source": [ "Note that this is a _decoder-only_ Transformer architecture --\n", "so it should be compared with the right-hand side of the original architecture diagram above."
] }, { "cell_type": "markdown", "metadata": { "id": "wvrRMd_RKp_G" }, "source": [ "Notice that the outputs of the attention blocks\n", "and of the MLP layers are\n", "added to their inputs, as in a ResNet.\n", "These operations are represented as \"Add & Norm\" layers in the classical diagram;\n", "normalization is ignored here for simplicity." ] }, { "cell_type": "markdown", "metadata": { "id": "o8n_iT-FFAbK" }, "source": [ "This total commitment to residual operations\n", "means the size of the embeddings\n", "(referred to as the \"model dimension\" or the \"embedding dimension\",\n", "here and below `d_model`)\n", "stays the same throughout the entire network.\n", "\n", "That means, for example,\n", "that the output of each layer can be used as input to the \"unembedding\" layer\n", "that produces logits.\n", "We can read out the computations of intermediate layers\n", "just by passing them through the unembedding layer\n", "and examining the logit tensor.\n", "See\n", "[\"interpreting GPT: the logit lens\"](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)\n", "for detailed experiments and interactive notebooks.\n", "\n", "In short, we observe a sort of \"progressive refinement\"\n", "of the next-token prediction\n", "as the embeddings proceed, depthwise, through the network." ] }, { "cell_type": "markdown", "metadata": { "id": "Ovh_3YgY9z2h" }, "source": [ "### Intuition #2: Transformer heads learn low rank transformations."
] }, { "cell_type": "markdown", "metadata": { "id": "XpNmozlnOdPC" }, "source": [ "In the original paper and in\n", "most presentations of Transformers,\n", "the attention layer is written like so:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PA7me8gNP5LE" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(Q \\cdot K^T) \\cdot V$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In pseudo-typed PyTorch (based loosely on\n", "[`torchtyping`](https://github.com/patrick-kidger/torchtyping))\n", "that looks like:" ] }, { "cell_type": "markdown", "metadata": { "id": "Oeict_6wGJgD" }, "source": [ "```python\n", "def classic_attention(\n", " Q: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " K: torch.Tensor[\"d_sequence\", \"d_model\"],\n", " V: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " return torch.softmax(Q @ K.T) @ V\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "8pewU90DSuOR" }, "source": [ "This is effectively exactly\n", "how it is written\n", "in PyTorch,\n", "apart from implementation details\n", "(look for `bmm` for the matrix multiplications and a `softmax` call):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WrgTpKFvOhwc" }, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "F._scaled_dot_product_attention??" 
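] }, { "cell_type": "markdown", "metadata": {}, "source": [
"For a version of `classic_attention` that actually runs, here is a minimal sketch.\n",
"The name `scaled_attention` is hypothetical and the inputs are random placeholders.\n",
"Unlike the formula above, it includes the division by the square root of the key dimension\n",
"used in the original paper.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"def scaled_attention(Q, K, V):\n",
"    # Q, K, V: (d_sequence, d_model); returns the output and the attention map\n",
"    logits = Q @ K.T / math.sqrt(K.shape[-1])\n",
"    attention_map = torch.softmax(logits, dim=-1)  # each row sums to 1\n",
"    return attention_map @ V, attention_map\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"Q, K, V = (torch.randn(4, 8) for _ in range(3))\n",
"out, attention_map = scaled_attention(Q, K, V)\n",
"print(out.shape, attention_map.sum(dim=-1))\n",
"```\n",
"\n",
"The softmax is taken over the last dimension,\n",
"so each row of the attention map is a distribution over the positions being attended to."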
] }, { "cell_type": "markdown", "metadata": { "id": "ebDXZ0tlSe7g" }, "source": [ "But the best way to write an operation so that a computer can execute it quickly\n", "is not necessarily the best way to write it so that a human can understand it --\n", "otherwise we'd all be coding in assembly.\n", "\n", "And this is a strange way to write it --\n", "you'll notice that what we normally think of\n", "as the \"inputs\" to the layer are not shown.\n", "\n", "We can instead write out the attention layer\n", "as a function of the inputs $x$.\n", "We write it for a single \"attention head\".\n", "Each attention layer includes a number of heads\n", "that read and write from the residual stream\n", "simultaneously and independently.\n", "We also add the output layer weights $W_O$\n", "and we get:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LuFNR67tQpsf" }, "outputs": [], "source": [ "display.Latex(r\"$\\text{softmax}(\\underbrace{x W_Q^T}_Q \\underbrace{W_K x^T}_{K^T}) \\underbrace{x W_V^T}_V W_O^T$\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SVnBjjfOLwxP" }, "source": [ "or, in pseudo-typed PyTorch:" ] }, { "cell_type": "markdown", "metadata": { "id": "LmpOm-HfGaNz" }, "source": [ "```python\n", "def rewrite_attention_single_head(x: torch.Tensor[\"d_sequence\", \"d_model\"]) -> torch.Tensor[\"d_sequence\", \"d_model\"]:\n", " query_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_Q\n", " key_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_K\n", " key_query_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_Q.T @ W_K\n", " # maps queries of residual stream to keys from residual stream, independent of position\n", "\n", " value_weights: torch.Tensor[\"d_head\", \"d_model\"] = W_V\n", " output_weights: torch.Tensor[\"d_model\", \"d_head\"] = W_O\n", " value_output_circuit: torch.Tensor[\"d_model\", \"d_model\"] = W_V.T @ W_O.T\n", " # transformation applied to each token, regardless of position\n", "\n", " attention_logits = 
x @ key_query_circuit @ x.T\n", " attention_map: torch.Tensor[\"d_sequence\", \"d_sequence\"] = torch.softmax(attention_logits, dim=-1)\n", " # maps positions to positions, often very sparse\n", "\n", " value_output: torch.Tensor[\"d_sequence\", \"d_model\"] = x @ value_output_circuit\n", "\n", " return attention_map @ value_output # transformed tokens filtered by attention map\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "dC0eqxZ6UAGT" }, "source": [ "Consider the `key_query_circuit`\n", "and `value_output_circuit`\n", "matrices, $W_{QK} := W_Q^TW_K$ and $W_{OV}^T := W_V^TW_O^T$.\n", "\n", "The key/query dimension, `d_head`,\n", "is small relative to the model's dimension, `d_model`,\n", "so $W_{QK}$ and $W_{OV}$ are very low rank,\n", "[which is the same as saying](https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Decomposition_rank)\n", "that they factorize into two matrices,\n", "one with a smaller number of rows\n", "and another with a smaller number of columns.\n", "That number is called the _rank_.\n", "\n", "When computing, these matrices are better represented via their components,\n", "rather than computed directly,\n", "which leads to the normal implementation of attention.\n", "\n", "In a large language model,\n", "the ratio of residual stream dimension, `d_model`, to\n", "the dimension of a single head, `d_head`, is huge, often 100:1.\n", "That means each query, key, and value computed at a position\n", "is a fairly simple, low-dimensional feature of the residual stream at that position.\n", "\n", "For visual intuition,\n", "we compare a matrix whose rank is one hundredth of full rank\n", "with a full rank matrix of the same size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_LUbojJMiW2C" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "\n", "\n", "low_rank = torch.randn(100, 1) @ torch.randn(1, 100)\n", "full_rank = torch.randn(100, 100)\n", "plt.figure(); 
plt.title(\"rank 1/100 matrix\"); plt.imshow(low_rank, cmap=\"Greys\"); plt.axis(\"off\")\n", "plt.figure(); plt.title(\"rank 100/100 matrix\"); plt.imshow(full_rank, cmap=\"Greys\"); plt.axis(\"off\");" ] }, { "cell_type": "markdown", "metadata": { "id": "lqBst92-OVka" }, "source": [ "The pattern in the first matrix is very simple,\n", "relative to the pattern in the second matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "SkCGrs9EiVh4" }, "source": [ "Another feature of low rank transformations is\n", "that they have a large nullspace or kernel --\n", "these are directions we can move the input without changing the output.\n", "\n", "That means that many changes to the residual stream won't affect the behavior of this head at all." ] }, { "cell_type": "markdown", "metadata": { "id": "UVz2dQgzhD4p" }, "source": [ "### Residuality and low rank together make Transformers less like a sequence model and more like a computer (that we can take gradients through)." ] }, { "cell_type": "markdown", "metadata": { "id": "hVlzwR03m8mC" }, "source": [ "The combination of residuality\n", "(changes are added to the current input)\n", "and low rank\n", "(only a small subspace is changed by each head)\n", "drastically changes the intuition about Transformers." ] }, { "cell_type": "markdown", "metadata": { "id": "qqjZI2jKe6HH" }, "source": [ "Rather than being an \"embedding of a token in its context\",\n", "the residual stream becomes something more like a memory or a scratchpad:\n", "one layer reads a small bit of information from the stream\n", "and writes a small bit of information back to it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5YIBkxlqepjc" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/transformer-layer-residual.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RtsKhkLfk00l" }, "source": [ "The residual stream works like a memory because it is roomy enough\n", "that these actions need not interfere:\n", "the subspaces targeted by reads and writes are small relative to the ambient space,\n", "so they can avoid overlapping.\n", "\n", "Additionally, the dimension of each head is still in the 100s in large models,\n", "and\n", "[high dimensional (>50) vector spaces have many \"almost-orthogonal\" vectors](https://link.springer.com/article/10.1007/s12559-009-9009-8)\n", "in them, so the number of effective degrees of freedom is\n", "actually larger than the dimension.\n", "This phenomenon allows high-dimensional tensors to serve as\n", "[very large content-addressable associative memories](https://arxiv.org/abs/2008.06996).\n", "There are\n", "[close connections between associative memory addressing algorithms and Transformer attention](https://arxiv.org/abs/2008.02217).\n", "\n", "Together, this means an early layer can write information to the stream\n", "that can be used by later layers -- by many of them at once, possibly much later.\n", "Later layers can learn to edit this information,\n", "e.g. deleting it,\n", "if doing so reduces the loss,\n", "but by default the information is preserved." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EragIygzJg86" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-stream-read-write.png\")" ] }, { "cell_type": "markdown", "metadata": { "id": "oKIaUZjwkpW7" }, "source": [ "Lastly, the softmax in the attention has a sparsifying effect,\n", "and so many attention heads are reading from\n", "just one token and writing to just one other token."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dN6VcJqIMKnB" }, "outputs": [], "source": [ "display.Image(url=base_url + \"/residual-token-to-token.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Repeatedly reading information from an external memory\n", "and using it to decide which operation to perform\n", "and where to write the results\n", "is at the core of the\n", "[Turing machine formalism](https://en.wikipedia.org/wiki/Turing_machine).\n", "For a concrete example, the\n", "[Transformer Circuits work](https://transformer-circuits.pub/2021/framework/index.html)\n", "includes a dissection of a form of \"pointer arithmetic\"\n", "that appears in some models." ] }, { "cell_type": "markdown", "metadata": { "id": "0kLFh7Mvnolr" }, "source": [ "This point of view seems\n", "very promising for explaining numerous\n", "otherwise perhaps counterintuitive features of Transformer models.\n", "\n", "- This framework predicts that Transformers will readily copy-and-paste information,\n", "which might explain phenomena like\n", "[incompletely trained Transformers repeating their outputs multiple times](https://youtu.be/SQLm9U0L0zM?t=1030).\n", "\n", "- It also readily explains\n", "[in-context learning behavior](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html),\n", "an important component of why Transformers perform well on medium-length texts\n", "and in few-shot learning.\n", "\n", "- Transformers also perform better on reasoning tasks when the text\n", "[\"let's think step-by-step\"](https://arxiv.org/abs/2205.11916)\n", "is added to their input prompt.\n", "This is partly due to the fact that that prompt is associated,\n", "in the dataset, with clearer reasoning,\n", "and since the models are trained to predict which tokens tend to appear\n", "after an input, they tend to produce better reasoning with that prompt --\n", "an explanation purely in terms of sequence modeling.\n", 
"But it also gives the Transformer license to generate a large number of tokens\n", "that act to store intermediate information,\n", "making for a richer residual stream\n", "for reading and writing." ] }, { "cell_type": "markdown", "metadata": { "id": "RyLRzgG-93yB" }, "source": [ "### Implementation detail: Transformers are position-insensitive by default." ] }, { "cell_type": "markdown", "metadata": { "id": "oR6PnrlA_hJ2" }, "source": [ "In the attention calculation\n", "each token can query each other token,\n", "with no regard for order.\n", "Furthermore, the construction of queries, keys, and values\n", "is based on the content of the embedding vector,\n", "which does not automatically include its position.\n", "\"dog bites man\" and \"man bites dog\" are identical, as in\n", "[bag-of-words modeling](https://machinelearningmastery.com/gentle-introduction-bag-words-model/).\n", "\n", "For most sequences,\n", "this is unacceptable:\n", "absolute and relative position matter\n", "and we cannot use the future to predict the past.\n", "\n", "We need to add two pieces to get a Transformer architecture that's usable for next-token prediction." ] }, { "cell_type": "markdown", "metadata": { "id": "EWHxGJz2-6ZK" }, "source": [ "First, the simpler piece:\n", "\"causal\" attention,\n", "so-named because it ensures that values earlier in the sequence\n", "are not influenced by later values, which would\n", "[violate causality](https://youtu.be/4xj0KRqzo-0?t=42)." 
] }, { "cell_type": "markdown", "metadata": { "id": "0c42xi6URYB4" }, "source": [ "The most common solution is straightforward:\n", "we calculate attention between all tokens,\n", "then throw out non-causal values by \"masking\" them\n", "(this is before applying the softmax,\n", "so masking means adding $-\\infty$).\n", "\n", "This feels wasteful --\n", "why are we calculating values we don't need?\n", "Trying to be smarter would be harder,\n", "and might rely on operations that aren't as optimized as\n", "matrix multiplication and addition.\n", "Furthermore, it's \"only\" twice as many operations,\n", "so it doesn't even show up in $O$-notation.\n", "\n", "A sample attention mask generated by our code base is shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NXaWe6pT-9jV" }, "outputs": [], "source": [ "from text_recognizer.models import transformer_util\n", "\n", "\n", "attention_mask = transformer_util.generate_square_subsequent_mask(100)\n", "\n", "ax = plt.matshow(torch.exp(attention_mask.T)); cb = plt.colorbar(ticks=[0, 1], fraction=0.05)\n", "plt.ylabel(\"Can the embedding at this index\"); plt.xlabel(\"attend to embeddings at this index?\")\n", "print(attention_mask[:10, :10].T); cb.set_ticklabels([False, True]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This solves our causality problem,\n", "but we still don't have positional information." 
] }, { "cell_type": "markdown", "metadata": { "id": "ZamUE4WIoGS2" }, "source": [ "The standard technique\n", "is to add alternating sines and cosines\n", "of different frequencies to the embeddings\n", "(there are\n", "[others](https://direct.mit.edu/coli/article/doi/10.1162/coli_a_00445/111478/Position-Information-in-Transformers-An-Overview),\n", "most notably\n", "[rotary embeddings](https://blog.eleuther.ai/rotary-embeddings/)).\n", "Each position in the sequence is then uniquely identifiable\n", "from the pattern of these values.\n", "\n", "> Furthermore, for the same reason that\n", " [translation-equivariant convolutions are related to Fourier transforms](https://math.stackexchange.com/questions/918345/fourier-transform-as-diagonalization-of-convolution),\n", " translations, e.g. relative positions, are fairly easy to express as linear transformations\n", " of sines and cosines." ] }, { "cell_type": "markdown", "metadata": { "id": "IDG2uOsaELU0" }, "source": [ "We superimpose this positional information on our embeddings.\n", "Note that because the model is residual,\n", "this position information will be by default preserved\n", "as it passes through the network,\n", "so it doesn't need to be repeatedly added."
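] }, { "cell_type": "markdown", "metadata": {}, "source": [
"As a concrete reference, here is a minimal sketch of the sinusoidal encoding,\n",
"assuming the formulation from the original Transformer paper (base 10000, even `d_model`).\n",
"The function name is hypothetical, and `transformer_util.PositionalEncoding`\n",
"may differ in details such as dropout and tensor layout.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import torch\n",
"\n",
"\n",
"def sinusoidal_positional_encoding(max_len, d_model):\n",
"    # assumes d_model is even; returns a (max_len, d_model) tensor\n",
"    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)\n",
"    # one frequency per sine/cosine pair, geometrically spaced\n",
"    div_term = torch.exp(\n",
"        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)\n",
"    )\n",
"    pe = torch.zeros(max_len, d_model)\n",
"    pe[:, 0::2] = torch.sin(position * div_term)  # even embedding dimensions\n",
"    pe[:, 1::2] = torch.cos(position * div_term)  # odd embedding dimensions\n",
"    return pe\n",
"\n",
"\n",
"pe_sketch = sinusoidal_positional_encoding(max_len=200, d_model=50)\n",
"print(pe_sketch.shape)\n",
"```\n",
"\n",
"Adding it is then a single sum, e.g. `embeddings + pe_sketch[: embeddings.shape[0]]`\n",
"for embeddings of shape `(d_sequence, d_model)`."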
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what this positional encoding looks like in our codebase:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5Zk62Q-a-1Ax" }, "outputs": [], "source": [ "PositionalEncoder = transformer_util.PositionalEncoding(d_model=50, dropout=0.0, max_len=200)\n", "\n", "pe = PositionalEncoder.pe.squeeze().T[:, :] # placing sequence dimension along the \"x-axis\"\n", "\n", "ax = plt.matshow(pe); plt.colorbar(ticks=[-1, 0, 1], fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Positional Encoding\", y=1.1)\n", "print(pe[:4, :8])" ] }, { "cell_type": "markdown", "metadata": { "id": "ep2ClIWvqDms" }, "source": [ "When we add the positional information to our embeddings,\n", "both the embedding information and the positional information\n", "are approximately preserved,\n", "as can be visually assessed below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PJuFjoCzC0Y4" }, "outputs": [], "source": [ "fake_embeddings = torch.randn_like(pe) * 0.5\n", "\n", "ax = plt.matshow(fake_embeddings); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings Without Positional Encoding\", y=1.1)\n", "\n", "fake_embeddings_with_pe = fake_embeddings + pe\n", "\n", "plt.matshow(fake_embeddings_with_pe); plt.colorbar(ticks=torch.arange(-2, 3), fraction=0.05)\n", "plt.xlabel(\"sequence index\"); plt.ylabel(\"embedding dimension\"); plt.title(\"Embeddings With Positional Encoding\", y=1.1);" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIzBxDkEmH8" }, "source": [ "A [similar technique](https://arxiv.org/abs/2103.06450)\n", "is used to also incorporate positional information into the image embeddings,\n", "which are flattened before being fed to the decoder."
] }, { "cell_type": "markdown", "metadata": { "id": "HC1N85wl8dvn" }, "source": [ "### Learn more about Transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "lJwYxkjTk15t" }, "source": [ "We're only able to give a flavor and an intuition for Transformers here.\n", "\n", "To improve your grasp on the nuts and bolts, check out the\n", "[original \"Attention Is All You Need\" paper](https://arxiv.org/abs/1706.03762),\n", "which is surprisingly approachable,\n", "as far as ML research papers go.\n", "The\n", "[Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)\n", "adds code and commentary to the original paper,\n", "which makes it even more digestible.\n", "For something even friendlier, check out the\n", "[Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)\n", "by Jay Alammar, which has an accompanying\n", "[video](https://youtu.be/-QH8fRhqFHM).\n", "\n", "Anthropic's work on\n", "[Transformer Circuits](https://transformer-circuits.pub/),\n", "summarized above, has some of the best material\n", "for building theoretical understanding\n", "and is still being updated with extensions and applications of the framework.\n", "The\n", "[accompanying exercises](https://transformer-circuits.pub/2021/exercises/index.html)\n", "are a great aid for checking and building your understanding.\n", "\n", "But they are fairly math-heavy.\n", "If you have more of a software engineering background, see\n", "Transformer Circuits co-author Nelson Elhage's blog post\n", "[Transformers for Software Engineers](https://blog.nelhage.com/post/transformers-for-software-engineers/).\n", "\n", "For a gentler introduction to the intuition for Transformers,\n", "check out Brandon Rohrer's\n", "[Transformers From Scratch](https://e2eml.school/transformers.html)\n", "tutorial." 
] }, { "cell_type": "markdown", "metadata": { "id": "qg7zntJES-aT" }, "source": [ "An aside:\n", "the matrix multiplications inside attention dominate\n", "the big-$O$ runtime of Transformers.\n", "So trying to make the attention mechanism more efficient, e.g. linear time,\n", "has generated a lot of research\n", "(review paper\n", "[here](https://arxiv.org/abs/2009.06732)).\n", "Despite drawing a lot of attention, so to speak,\n", "at the time of writing in mid-2022, these methods\n", "[haven't been used in large language models](https://twitter.com/MitchellAGordon/status/1545932726775193601),\n", "so it isn't likely to be worth the effort to spend time learning about them\n", "unless you are a Transformer specialist." ] }, { "cell_type": "markdown", "metadata": { "id": "vCjXysEJ8g9_" }, "source": [ "# Using Transformers to read paragraphs of text" ] }, { "cell_type": "markdown", "metadata": { "id": "KsfKWnOvqjva" }, "source": [ "Our simple convolutional model for text recognition from\n", "[Lab 02b](https://fsdl.me/lab02b-colab)\n", "could only handle cleanly-separated characters.\n", "\n", "It worked by sliding a LeNet-style CNN\n", "over the image,\n", "predicting a character for each step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "njLdzBqy-I90" }, "outputs": [], "source": [ "import text_recognizer.data\n", "\n", "\n", "emnist_lines = text_recognizer.data.EMNISTLines()\n", "line_cnn = text_recognizer.models.LineCNNSimple(emnist_lines.config())\n", "\n", "# for sliding, see the for loop over range(S)\n", "line_cnn.forward??" ] }, { "cell_type": "markdown", "metadata": { "id": "K0N6yDBQq8ns" }, "source": [ "But unfortunately for us, handwritten text\n", "doesn't come in neatly-separated characters\n", "of equal size, so we trained our model on synthetic data\n", "designed to work with that model." 
] }, { "cell_type": "markdown", "metadata": { "id": "hiqUVbj0sxLr" }, "source": [ "Now that we have a better model,\n", "we can work with better data:\n", "paragraphs from the\n", "[IAM Handwriting database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database)." ] }, { "cell_type": "markdown", "metadata": { "id": "oizsOAcKs-dD" }, "source": [ "The cell below uses our `LightningDataModule`\n", "to download and preprocess this data,\n", "writing results to disk.\n", "We can then spin up `DataLoader`s to give us batches.\n", "\n", "It can take several minutes to run the first time\n", "on commodity machines,\n", "with most time spent extracting the data.\n", "On subsequent runs,\n", "the time-consuming operations will not be repeated." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uL9LHbjdsUbm" }, "outputs": [], "source": [ "iam_paragraphs = text_recognizer.data.IAMParagraphs()\n", "\n", "iam_paragraphs.prepare_data()\n", "iam_paragraphs.setup()\n", "xs, ys = next(iter(iam_paragraphs.val_dataloader()))\n", "\n", "iam_paragraphs" ] }, { "cell_type": "markdown", "metadata": { "id": "nBkFN9bbTm_S" }, "source": [ "Now that we've got a batch,\n", "let's take a look at some samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hqaps8yxtBhU" }, "outputs": [], "source": [ "import random\n", "\n", "import numpy as np\n", "import wandb\n", "\n", "\n", "def show(y):\n", " y = y.detach().cpu() # bring back from accelerator if it's being used\n", " return \"\".join(np.array(iam_paragraphs.mapping)[y]).replace(\"

\", \"\")\n", "\n", "idx = random.randint(0, len(xs) - 1) # randint includes both endpoints\n", "\n", "print(show(ys[idx]))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "4dT3UCNzTsoc" }, "source": [ "The `ResnetTransformer` model can run on this data\n", "if passed the `.config`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WXL-vIGRr86D" }, "outputs": [], "source": [ "import text_recognizer.models\n", "\n", "\n", "rnt = text_recognizer.models.ResnetTransformer(data_config=iam_paragraphs.config())" ] }, { "cell_type": "markdown", "metadata": { "id": "MMxa-oWyT01E" }, "source": [ "Our models are now big enough\n", "that we want to make use of GPU acceleration\n", "as much as we can,\n", "even when working on single inputs,\n", "so let's cast to the GPU if we have one." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YyUM8LgvW0w" }, "outputs": [], "source": [ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "\n", "rnt.to(device); xs = xs.to(device); ys = ys.to(device);" ] }, { "cell_type": "markdown", "metadata": { "id": "Y-E3UdD4zUJi" }, "source": [ "First, let's just pass it through the ResNet encoder."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-LUUtlvaxrvg" }, "outputs": [], "source": [ "resnet_embedding, = rnt.resnet(xs[idx:idx+1].repeat(1, 3, 1, 1))\n", " # resnet is designed for RGB images, so we replicate the input across channels 3 times" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eimgJ5dnywjg" }, "outputs": [], "source": [ "resnet_idx = random.randint(0, len(resnet_embedding) - 1) # re-execute to view a different channel\n", "plt.matshow(resnet_embedding[resnet_idx].detach().cpu(), cmap=\"Greys_r\");\n", "plt.axis(\"off\"); plt.colorbar(fraction=0.05);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These embeddings, though generated by random, untrained weights,\n", "are not entirely useless.\n", "\n", "Before neural networks could be effectively\n", "trained end to end,\n", "they were often used with frozen random weights\n", "everywhere except the final layer\n", "(see e.g.\n", "[Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network)).\n", "[As late as 2015](https://www.cv-foundation.org/openaccess/content_cvpr_workshops_2015/W13/html/Paisitkriangkrai_Effective_Semantic_Pixel_2015_CVPR_paper.html),\n", "these methods were still competitive, and\n", "[Neural Tangent Kernels](https://arxiv.org/abs/1806.07572)\n", "provide a\n", "[theoretical basis](https://arxiv.org/abs/2011.14522)\n", "for understanding their performance." ] }, { "cell_type": "markdown", "metadata": { "id": "ye6pW0ETzw2A" }, "source": [ "The final result, though, is repetitive gibberish --\n", "at the bare minimum, we need to train the unembedding/readout layer\n", "in order to get reasonable text." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our architecture includes randomization with dropout,\n", "so repeated runs of the cell below will generate different outcomes."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xu3Pa7gLsFMo" }, "outputs": [], "source": [ "preds, = rnt(xs[idx:idx+1]) # can take up to two minutes on a CPU. Transformers ❤️ GPUs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gvCXUbskv6XM" }, "outputs": [], "source": [ "print(show(preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without teacher forcing, runtime is also variable from iteration to iteration --\n", "the model stops when it generates an \"end sequence\" or padding token,\n", "which is not deterministic thanks to the dropout layers.\n", "For similar reasons, runtime is variable across inputs.\n", "\n", "The variable runtime of autoregressive generation\n", "is also not great for scaling.\n", "In a distributed setting, as required for large scale,\n", "forward passes need to be synced across devices,\n", "and if one device is generating a batch of much longer sequences,\n", "it will cause all the others to idle while they wait on it to finish." ] }, { "cell_type": "markdown", "metadata": { "id": "t76MSVRXV0V7" }, "source": [ "Let's turn our model into a `TransformerLitModel`\n", "so we can run with teacher forcing.\n", "\n", "> You may be wondering:\n", " why isn't teacher forcing part of the PyTorch module?\n", " In general, the `LightningModule`\n", " should encapsulate things that are needed in training, validation, and testing\n", " but not during inference.\n", " The teacher forcing trick fits this paradigm,\n", " even though it's so critical to what makes Transformers powerful." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8qrHRKHowdDi" }, "outputs": [], "source": [ "import text_recognizer.lit_models\n", "\n", "lit_rnt = text_recognizer.lit_models.TransformerLitModel(rnt)" ] }, { "cell_type": "markdown", "metadata": { "id": "MlNaFqR50Oid" }, "source": [ "Now we can use `.teacher_forward` if we also provide the target `ys`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lpZdqXS5wn0F" }, "outputs": [], "source": [ "forcing_outs, = lit_rnt.teacher_forward(xs[idx:idx+1], ys[idx:idx+1])" ] }, { "cell_type": "markdown", "metadata": { "id": "0Zx9SmsN0QLT" }, "source": [ "This may not run faster than the `rnt.forward`,\n", "since generations are always the maximum possible length,\n", "but runtimes and output lengths are deterministic and constant." ] }, { "cell_type": "markdown", "metadata": { "id": "tu-XNYpi0Qvi" }, "source": [ "Forcing doesn't necessarily make our predictions better.\n", "They remain highly repetitive gibberish." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JcEgify9w0sv" }, "outputs": [], "source": [ "forcing_preds = torch.argmax(forcing_outs, dim=0)\n", "\n", "print(show(forcing_preds.cpu()))\n", "wandb.Image(xs[idx]).image" ] }, { "cell_type": "markdown", "metadata": { "id": "xn6GGNzc9a3o" }, "source": [ "## Training the `ResnetTransformer`" ] }, { "cell_type": "markdown", "metadata": { "id": "uvZYsuSyWUXe" }, "source": [ "We're finally ready to train this model on full paragraphs of handwritten text!"
] }, { "cell_type": "markdown", "metadata": { "id": "3cJwC7b720Sd" }, "source": [ "This is a more serious model --\n", "it's the one we use in the\n", "[deployed TextRecognizer application](http://fsdl.me/app).\n", "It's much larger than the models we've seen thus far,\n", "so it can easily outstrip available compute resources,\n", "in particular GPU memory.\n", "\n", "To help, we use\n", "[automatic mixed precision](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/precision.html),\n", "which shrinks the size of most of our floats by half,\n", "reducing memory consumption and potentially speeding up computation.\n", "\n", "If your GPU has less than 8GB of available RAM,\n", "you'll see a \"CUDA out of memory\" `RuntimeError`,\n", "which is something of a\n", "[rite of passage in ML](https://twitter.com/Suhail/status/1549555136350982145).\n", "In this case, you can resolve it by reducing the `--batch_size`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w1mXlhfy04Nm" }, "outputs": [], "source": [ "import torch\n", "\n", "gpus = int(torch.cuda.is_available())\n", "\n", "if gpus:\n", " !nvidia-smi\n", "else:\n", " print(\"watch out! working with this model on a typical CPU is not feasible\")" ] }, { "cell_type": "markdown", "metadata": { "id": "os1vW1rPZ1dy" }, "source": [ "Even with an okay GPU, like a\n", "[Tesla P100](https://www.nvidia.com/en-us/data-center/tesla-p100/),\n", "a single epoch of training can take over 10 minutes to run.\n", "We use the `--limit_{train/val/test}_batches` flags to keep the runtime short,\n", "but you can remove those flags to see what full training looks like."
] }, { "cell_type": "markdown", "metadata": { "id": "vnF6dWFn4JlZ" }, "source": [ "It can take a long time (overnight)\n", "to train this model to decent performance on a single GPU,\n", "so we'll focus on other pieces for the exercises.\n", "\n", "> At the time of writing in mid-2022, the cheapest readily available option\n", "for training this model to decent performance on this dataset with this codebase\n", "comes out around $10, using\n", "[the 8xV100 instance on Lambda Labs' GPU Cloud](https://lambdalabs.com/service/gpu-cloud).\n", "See, for example,\n", "[this dashboard](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw)\n", "and associated experiment.\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HufjdUZN0t4l", "scrolled": false }, "outputs": [], "source": [ "%%time\n", "# above %%magic times the cell, useful as a poor man's profiler\n", "\n", "%run training/run_experiment.py --data_class IAMParagraphs --model_class ResnetTransformer --loss transformer \\\n", " --gpus={gpus} --batch_size 16 --precision 16 \\\n", " --limit_train_batches 10 --limit_test_batches 1 --limit_val_batches 2" ] }, { "cell_type": "markdown", "metadata": { "id": "L6fQ93ju3Iku" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "udb1Ekjx3L63" }, "source": [ "### 🌟 Try out gradient accumulation and other \"training tricks\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "kpqViB4p3Wfb" }, "source": [ "Larger batches are helpful not only for increasing parallelization\n", "and amortizing fixed costs\n", "but also for getting more reliable gradients.\n", "Larger batches give gradients with less noise\n", "and to a point, less gradient noise means faster convergence.\n", "\n", "But larger batches result in larger tensors,\n", "which take up more GPU memory,\n", "a resource that is tightly constrained\n", "and device-dependent.\n", "\n", "Does that mean we are limited in the quality of our gradients\n", "due to our machine size?\n", "\n", "Not entirely:\n", "look up the `--accumulate_grad_batches`\n", "argument to the `pl.Trainer`.\n", "You should be able to understand why\n", "it makes it possible to compute the same gradients\n", "you would find for a batch of size `k * N`\n", "on a machine that can only run batches up to size `N`.\n", "\n", "Accumulating gradients across batches is among the\n", "[advanced training tricks supported by Lightning](https://pytorch-lightning.readthedocs.io/en/1.6.3/advanced/training_tricks.html).\n", "Try some of them out!\n", "Keep the `--limit_{blah}_batches` flags in place so you can quickly experiment." 
] }, { "cell_type": "markdown", "metadata": { "id": "b2vtkmX830y3" }, "source": [ "### 🌟🌟 Find the smallest model that can still fit a single batch of 16 examples.\n", "\n", "While training this model to actually fit the whole dataset is infeasible\n", "as a short exercise on commodity hardware,\n", "it's practical to train this model to memorize a batch of 16 examples.\n", "\n", "Passing `--overfit_batches 1` flag limits the number of training batches to 1\n", "and turns off\n", "[`DataLoader` shuffling](https://discuss.pytorch.org/t/how-does-shuffle-in-data-loader-work/49756)\n", "so that in each epoch, the model just sees the same single batch of data over and over again.\n", "\n", "At first, try training the model to a loss of `2.5` --\n", "it should be doable in 100 epochs or less,\n", "which is just a few minutes on a commodity GPU.\n", "\n", "Once you've got that working,\n", "crank up the number of epochs by a factor of 10\n", "and confirm that the loss continues to go down.\n", "\n", "Some tips:\n", "\n", "- Use `--limit_test_batches 0` to turn off testing.\n", "We don't need it because we don't care about generalization\n", "and it's relatively slow because it runs the model autoregressively.\n", "\n", "- Use `--help` and look through the model class args\n", "to find the arguments used to reduce model size.\n", "\n", "- By default, there's lots of regularization to prevent overfitting.\n", "Look through the args for the model class and data class\n", "for regularization knobs to turn off or down." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab03_transformers.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab08/notebooks/lab04_experiments.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 04: Experiment Management" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How experiment management brings observability to ML model development\n", "- Which features of experiment management we use in developing the Text Recognizer\n", "- Workflows for using Weights & Biases in experiment management, including metric logging, artifact versioning, and hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 4\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This lab contains a large number of embedded iframes\n", "that benefit from having a wide window.\n", "The cell below makes the notebook as wide as your browser window\n", "if `full_width` is set to `True`.\n", "Full width is the default behavior in Colab,\n", "so this cell is intended to improve the viewing experience in other Jupyter environments."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-04-video-embed\", width=\"50%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Why experiment management?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To understand why we need experiment management for ML development,\n", "let's start by running an experiment.\n", "\n", "We'll train a new model on a new dataset,\n", "using the training script `training/run_experiment.py`\n", "introduced in [Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use a CNN encoder and Transformer decoder, as in\n", "[Lab 03](https://fsdl.me/lab03-colab),\n", "but with some changes so we can iterate faster.\n", "We'll operate on just single lines of text at a time (`--data_class IAMLines`), as in\n", "[Lab 02b](https://fsdl.me/lab02b-colab),\n", "and we'll use a smaller CNN (`--model_class LineCNNTransformer`)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM # base dataset of images of handwritten text\n", "from text_recognizer.data import IAMLines # processed version split into individual lines\n", "from text_recognizer.models import LineCNNTransformer # simple CNN encoder / Transformer decoder\n", "\n", "\n", "print(IAM.__doc__)\n", "\n", "# uncomment a line below for details on either class\n", "# IAMLines?? \n", "# LineCNNTransformer??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below will train a model on 10% of the data for two epochs.\n", "\n", "It takes up to a few minutes to run on commodity hardware,\n", "including data download and preprocessing.\n", "As it's running, continue reading below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%%time\n", "import torch\n", "\n", "\n", "gpus = int(torch.cuda.is_available()) \n", "\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 2 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 --limit_test_batches 0.1 --log_every_n_steps 10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the model trains, we're calculating lots of metrics --\n", "loss on training and validation, [character error rate](https://torchmetrics.readthedocs.io/en/v0.7.3/references/functional.html#char-error-rate-func) --\n", "and reporting them to the terminal.\n", "\n", "This is achieved by the built-in `.log` method\n", "([docs](https://pytorch-lightning.readthedocs.io/en/1.6.1/common/lightning_module.html#train-epoch-level-metrics))\n", "of the `LightningModule`,\n", "and it is a very straightforward way to get basic information about your experiment as it's running\n", "without leaving the context where you're running it." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Learning to read\n", "[information from streaming numbers in the command line](http://www.quickmeme.com/img/45/4502c7603faf94c0e431761368e9573df164fad15f1bbc27fc03ad493f010dea.jpg)\n", "is something of a rite of passage for MLEs, but\n", "let's consider what we can't see here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We're missing all metric values except the most recent --\n", "we can see them as they stream in, but they're constantly overwritten.\n", "We also can't associate them with timestamps, steps, or epochs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We also don't see any system metrics.\n", "We can't see how much the GPU is being utilized, how much CPU RAM is free, or how saturated our I/O bandwidth is\n", "without launching a separate process.\n", "And even if we do, those values will also not be saved and timestamped,\n", "so we can't correlate them with other things during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- As we continue to run experiments, changing code and opening new terminals,\n", "even the information we have or could figure out now will disappear.\n", "Say you spot a weird error message during training,\n", "but your session ends and the stdout is gone,\n", "so you don't know exactly what it was.\n", "Can you recreate the error?\n", "Which git branch and commit were you on?\n", "Did you have any uncommitted changes? Which arguments did you pass?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Also, model checkpoints containing the parameter values have been saved to disk.\n", "Can we relate these checkpoints to their metrics, both in terms of accuracy and in terms of performance?\n", "As we run more and more experiments,\n", "we'll want to slice and dice them to see if,\n", "say, models with `--lr 0.001` are generally better or worse than models with `--lr 0.0001`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to save and log all of this information, and more, in order to make our model training\n", "[observable](https://docs.honeycomb.io/getting-started/learning-about-observability/) --\n", "in short, so that we can understand, make decisions about, and debug our model training\n", "by looking at logs and source code, without having to recreate it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we had to write the logging code we need to save this information ourselves, that'd put us in for a world of hurt:\n", "1. That's a lot of code that's not at the core of building an ML-powered system. Robustly saving version control information means becoming _very_ good with your VCS, which is less time spent on mastering the important stuff -- your data, your models, and your problem domain.\n", "2. It's very easy to forget to log something that you don't yet realize is going to be critical at some point. Data on network traffic, disk I/O, and GPU/CPU syncing is unimportant until suddenly your training has slowed to a crawl 12 hours into training and you can't figure out where the bottleneck is.\n", "3. Once you do start logging everything that's necessary, you might find it's not performant enough -- the code you wrote so you can debug performance issues is [tanking your performance](https://i.imgflip.com/6q54og.jpg).\n", "4. Just logging is not enough. The bytes of data need to be made legible to humans in a GUI and searchable via an API, or else they'll be too hard to use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Local Experiment Tracking with TensorBoard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Luckily, we don't have to. PyTorch Lightning integrates with other libraries for additional logging features,\n", "and it makes logging very easy."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `.log` method of the `LightningModule` isn't just for logging to the terminal.\n", "\n", "It can also use a logger to push information elsewhere.\n", "\n", "By default, we use\n", "[TensorBoard](https://www.tensorflow.org/tensorboard)\n", "via the Lightning `TensorBoardLogger`,\n", "which has been saving results to the local disk.\n", "\n", "Let's find them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# we use a sequence of bash commands to get the latest experiment's directory\n", "# by hand, you can just copy and paste it from the terminal\n", "\n", "list_all_log_files = \"find training/logs/lightning_logs/\" # find avoids issues ls has with \\n in filenames\n", "filter_to_folders = \"grep '_[0-9]*$'\" # regex match on end of line\n", "sort_version_descending = \"sort -Vr\" # uses \"version\" sorting (-V) and reverses (-r)\n", "take_first = \"head -n 1\" # the first n elements, n=1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latest_log, = ! {list_all_log_files} | {filter_to_folders} | {sort_version_descending} | {take_first}\n", "latest_log" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!ls -lh {latest_log}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, we need to launch a TensorBoard server --\n", "much like we need to launch a Jupyter server to use Jupyter notebooks.\n", "\n", "The cells below load an extension that lets you use TensorBoard inside of a notebook\n", "the same way you'd use it from the command line, and then launch it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext tensorboard" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# same command works in terminal, with \"{arguments}\" replaced with values or \"$VARIABLES\"\n", "\n", "port = 11717 # pick an open port on your machine\n", "host = \"0.0.0.0\" # allow connections from the internet\n", " # watch out! make sure you turn TensorBoard off\n", "\n", "%tensorboard --logdir {latest_log} --port {port} --host {host}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see some charts of metrics over time along with some charting controls.\n", "\n", "You can click around in this interface and explore it if you'd like,\n", "but in the next section, we'll see that there are better tools for experiment management." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've run many experiments on this machine,\n", "you can see all of their results by pointing TensorBoard\n", "at the whole `lightning_logs` directory,\n", "rather than just one experiment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%tensorboard --logdir training/logs/lightning_logs --port {port + 1} --host \"0.0.0.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For large numbers of experiments, the management experience is not great --\n", "for example, it's hard to go from a line in a chart to metadata about the experiment or metric depicted in that line.\n", "\n", "It's especially difficult to switch between types of experiments, to compare experiments run on different machines, or to collaborate with others,\n", "which are important workflows as applications mature and teams grow."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is an independent service, so we need to make sure we turn it off when we're done. Just flip `done_with_tensorboard` to `True`.\n", "\n", "If you run into any issues with the above cells failing to launch,\n", "especially across iterations of this lab, run this cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorboard.manager\n", "\n", "# get the process IDs for all tensorboard instances\n", "pids = [tb.pid for tb in tensorboard.manager.get_all()]\n", "\n", "done_with_tensorboard = False\n", "\n", "if done_with_tensorboard:\n", " # kill processes\n", " for pid in pids:\n", " !kill {pid} 2> /dev/null\n", " \n", " # remove the temporary files that sometimes persist, see https://stackoverflow.com/a/59582163\n", " !rm -rf {tensorboard.manager._get_info_dir()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiment Management with Weights & Biases" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How do we manage experiments when we hit the limits of local TensorBoard?"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorBoard is powerful and flexible and very scalable,\n", "but running it requires engineering effort and babysitting --\n", "you're running a database, writing data to it,\n", "and layering a web application over it.\n", "\n", "This is a fairly common workflow for web developers,\n", "but not so much for ML engineers.\n", "\n", "You can avoid this with [tensorboard.dev](https://tensorboard.dev/),\n", "and it's as simple as running the command `tensorboard dev upload`\n", "pointed at your logging directory.\n", "\n", "But there are strict limits to this free service:\n", "1GB of tensor data and 1GB of binary data.\n", "A single Text Recognizer model checkpoint is ~100MB,\n", "and that's not particularly large for a useful model.\n", "\n", "Furthermore, all data is public,\n", "so if you upload the inputs and outputs of your model,\n", "anyone who finds the link can see them.\n", "\n", "Overall, tensorboard.dev works very well for certain academic and open projects\n", "but not for industrial ML." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To avoid that narrow permissions and limits issue,\n", "you could use [git LFS](https://git-lfs.github.com/)\n", "to track the binary data and tensor data,\n", "which is more likely to be sensitive than metrics.\n", "\n", "The Hugging Face ecosystem uses TensorBoard and git LFS.\n", "\n", "It includes the Hugging Face Hub, a git server much like GitHub,\n", "but designed first and foremost for collaboration on models and datasets,\n", "rather than collaboration on code.\n", "For example, the Hugging Face Hub\n", "[will host TensorBoard alongside models](https://huggingface.co/docs/hub/tensorboard)\n", "and officially has\n", "[no storage limit](https://discuss.huggingface.co/t/is-there-a-size-limit-for-dataset-hosting/14861/4),\n", "avoiding the\n", "[bandwidth and storage pricing](https://docs.github.com/en/repositories/working-with-files/managing-large-files/about-storage-and-bandwidth-usage)\n", "that make using git LFS with GitHub expensive.\n", "\n", "However, we prefer to avoid mixing software version control and experiment management.\n", "\n", "First, using the Hub requires maintaining an additional git remote,\n", "which is a hard ask for many engineering teams.\n", "\n", "Secondly, git-style versioning is an awkward fit for logging --\n", "is it really sensible to create a new commit for each logging event while you're watching live?\n", "\n", "Instead, we prefer to use systems that solve experiment management with _databases_." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are multiple alternatives to TensorBoard + git LFS that fit this bill.\n", "The primary [open governance](https://www.ibm.com/blogs/cloud-computing/2016/10/27/open-source-open-governance/)\n", "tool is [MLflow](https://github.com/mlflow/mlflow/)\n", "and there are a number of\n", "[closed-governance and/or closed-source tools](https://www.reddit.com/r/MachineLearning/comments/q5g7m9/n_sagemaker_experiments_vs_comet_neptune_wandb_etc/).\n", "\n", "These tools generally avoid any need to worry about hosting\n", "(unless data governance rules require a self-hosted version).\n", "\n", "For a sampling of publicly-posted opinions on experiment management tools,\n", "see these discussions from Reddit:\n", "\n", "- r/mlops: [1](https://www.reddit.com/r/mlops/comments/uxieq3/is_weights_and_biases_worth_the_money/), [2](https://www.reddit.com/r/mlops/comments/sbtkxz/best_mlops_platform_for_2022/)\n", "- r/MachineLearning: [3](https://www.reddit.com/r/MachineLearning/comments/sqa36p/comment/hwls9px/?utm_source=share&utm_medium=web2x&context=3)\n", "\n", "Among these tools, the FSDL recommendation is\n", "[Weights & Biases](https://wandb.ai),\n", "which we believe offers\n", "- the best user experience, both in the Python SDKs and in the graphical interface\n", "- the best integrations with other tools,\n", "including\n", "[Lightning](https://docs.wandb.ai/guides/integrations/lightning) and\n", "[Keras](https://docs.wandb.ai/guides/integrations/keras),\n", "[Jupyter](https://docs.wandb.ai/guides/track/jupyter),\n", "and even\n", "[TensorBoard](https://docs.wandb.ai/guides/integrations/tensorboard),\n", "and\n", "- the best tools for collaboration.\n", "\n", "Below, we'll take care to point out which logging and management features\n", "are available via generic interfaces in Lightning and which are W&B-specific."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import wandb\n", "\n", "print(wandb.__doc__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding it to our experiment running code is extremely easy,\n", "relative to the features we get, which is\n", "one of the main selling points of W&B.\n", "\n", "We get most of our new experiment management features just by changing a single variable, `logger`, from\n", "`TensorBoardLogger` to `WandbLogger`\n", "and adding two lines of code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"args.wandb\" -A 5 training/run_experiment.py | head -n 6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see what each of these lines does for us below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this logger is built into and maintained by PyTorch Lightning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning.loggers import WandbLogger\n", "\n", "\n", "WandbLogger??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to complete the rest of this notebook,\n", "you'll need a Weights & Biases account.\n", "\n", "As with GitHub, the free tier for personal, academic, and open source work\n", "is very generous.\n", "\n", "The Text Recognizer project will fit comfortably within the free tier.\n", "\n", "Run the cell below and follow the prompts to log in or create an account, or go\n", "[here](https://wandb.ai/signup)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb login" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the cell below to launch an experiment tracked with Weights & Biases.\n", "\n", "The experiment can take between 3 and 10 minutes to run.\n", "In that time, continue reading below."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 10 \\\n", " --log_every_n_steps 10 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1\n", " \n", "last_expt = wandb.run\n", "\n", "wandb.finish() # necessary in this style of in-notebook experiment running, not necessary in CLI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see some new things in our output.\n", "\n", "For example, there's a note from `wandb` that the data is saved locally\n", "and also synced to their servers.\n", "\n", "There's a link to a webpage for viewing the logged data and a name for our experiment --\n", "something like `dandy-sunset-1`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The local logging and cloud syncing happen with minimal impact on performance,\n", "because `wandb` launches a separate process to listen for events and upload them.\n", "\n", "That's a table-stakes feature for a logging framework but not a pleasant thing to write in Python yourself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Runs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To view results, head to the link in the notebook output\n", "that looks like \"Syncing run **{adjective}-{noun}-{number}**\".\n", "\n", "There's no need to wait for training to finish.\n", "\n", "The next sections describe the contents of that interface. You can read them while looking at the W&B interface in a separate tab or window."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For even more convenience, once training is finished we can also see the results directly in the notebook by embedding the webpage:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(last_expt.url)\n", "IFrame(last_expt.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have landed on the run page\n", "([docs](https://docs.wandb.ai/ref/app/pages/run-page)),\n", "which collects up all of the information for a single experiment into a collection of tabs.\n", "\n", "We'll work through these tabs from top to bottom.\n", "\n", "Each header is also a link to the documentation for a tab." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Overview tab](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab)\n", "This tab has an icon that looks like `(i)` or 🛈.\n", "\n", "The top section of this tab has high-level information about our run:\n", "- Timing information, like start time and duration\n", "- System hardware, hostname, and basic environment info\n", "- Git repository link and state\n", "\n", "This information is collected and logged automatically.\n", "\n", "The section at the bottom contains configuration information, which here includes all CLI args or their defaults,\n", "and summary metrics.\n", "\n", "Configuration information is collected with `.log_hyperparams` in Lightning or `wandb.config` otherwise." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Charts tab](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab)\n", "\n", "This tab has a line plot icon, something like 📈.\n", "\n", "It's also the default page you land on when looking at a W&B run.\n", "\n", "Charts are generated for everything we `.log` from PyTorch Lightning. 
The charts here are interactive and editable, and changes persist.\n", "\n", "Unfurl the \"Gradients\" section in this tab to check out the gradient histograms. These histograms can be useful for debugging training instability issues.\n", "\n", "We were able to log these just by calling `wandb.watch` on our model. This is a W&B-specific feature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [System tab](https://docs.wandb.ai/ref/app/pages/run-page#system-tab)\n", "This tab has a computer chip icon.\n", "\n", "It contains\n", "- GPU metrics for all GPUs: temperature, [utilization](https://stackoverflow.com/questions/5086814/how-is-gpu-and-memory-utilization-defined-in-nvidia-smi-results), and memory allocation\n", "- CPU metrics: memory usage, utilization, thread counts\n", "- Disk and network I/O levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Model tab](https://docs.wandb.ai/ref/app/pages/run-page#model-tab)\n", "This tab has an undirected graph icon that looks suspiciously like a [pawnbrokers' symbol](https://en.wikipedia.org/wiki/Pawnbroker#:~:text=The%20pawnbrokers%27%20symbol%20is%20three,the%20name%20of%20Lombard%20banking.).\n", "\n", "The information here was also generated from `wandb.watch`, and includes parameter counts and input/output shapes for all layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Logs tab](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab)\n", "This tab has an icon that looks like a stylized command prompt, `>_`.\n", "\n", "It contains information that was printed to stdout.\n", "\n", "This tab is useful for, e.g., determining when exactly a warning or error message started appearing.\n", "\n", "Note that model summary information is printed here. We achieve this with a Lightning `Callback` called `ModelSummary`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelSummary\" training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lightning `Callback`s add extra \"nice-to-have\" engineering features to our model training.\n", "\n", "For more on Lightning `Callback`s, see\n", "[Lab 02a](https://fsdl.me/lab02a-colab)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Files tab](https://docs.wandb.ai/ref/app/pages/run-page#files-tab)\n", "This tab has a stylized document icon, something like 📄.\n", "\n", "You can use this tab to view any files saved with `wandb.save`.\n", "\n", "For most uses, that style is deprecated in favor of `wandb.log_artifact`,\n", "which we'll discuss shortly.\n", "\n", "But a few pieces of information automatically collected by W&B end up in this tab.\n", "\n", "Some highlights:\n", " - Much more detailed environment info: `conda-environment.yaml` and `requirements.txt`\n", " - A `diff.patch` that represents the difference between the files in the `git` commit logged in the overview and the actual disk state." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [Artifacts tab](https://docs.wandb.ai/ref/app/pages/run-page#artifacts-tab)\n", "This tab has the database or [drum memory icon](https://stackoverflow.com/a/2822750), which looks like a cylinder of three stacked hockey pucks.\n", "\n", "This tab contains all of the versioned binary files, aka artifacts, associated with our run.\n", "\n", "We store two kinds of binary files:\n", " - `run_table`s of model inputs and outputs\n", " - `model` checkpoints\n", "\n", "We get model checkpoints via the built-in Lightning `ModelCheckpoint` callback, which is not specific to W&B."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!grep \"callbacks.ModelCheckpoint\" -A 9 training/run_experiment.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The tools for working with artifacts in W&B are powerful and complex, so we'll cover them in various places throughout this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interactive Tables of Logged Media" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Returning to the Charts tab,\n", "notice that we have model inputs and outputs logged in structured tables\n", "under the train, validation, and test sections.\n", "\n", "These tables are interactive as well\n", "([docs](https://docs.wandb.ai/guides/data-vis/log-tables)).\n", "They support basic exploratory data analysis and are compatible with W&B's collaboration features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to charts in our run page, these tables also have their own pages inside the W&B web app." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "table_versions_url = last_expt.url.split(\"runs\")[0] + f\"artifacts/run_table/run-{last_expt.id}-trainpredictions/\"\n", "table_data_url = table_versions_url + \"v0/files/train/predictions.table.json\"\n", "\n", "print(table_data_url)\n", "IFrame(src=table_data_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting this to work requires more effort and more W&B-specific code\n", "than the other features we've seen so far.\n", "\n", "We'll briefly explain the implementation here, for those who are interested.\n", "\n", "We use a custom Lightning `Callback`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.callbacks.imtotext import ImageToTextTableLogger\n", "\n", "\n", "ImageToTextTableLogger??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, Lightning returns logged information on every batch and these outputs are accumulated throughout an epoch.\n", "\n", "The values are then aggregated with a frequency determined by the `pl.Trainer` argument `--log_every_n_steps`.\n", "\n", "This behavior is sensible for metrics, which are low overhead, but not so much for media,\n", "where we'd rather subsample and avoid holding on to too much information.\n", "\n", "So we additionally control when media is included in the outputs with methods like `add_on_logged_batches`.\n", "\n", "The frequency of media logging is then controlled with `--log_every_n_steps`, as with aggregate metric reporting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from text_recognizer.lit_models.base import BaseImageToTextLitModel\n", "\n", "BaseImageToTextLitModel.add_on_logged_batches??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Projects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Everything we've seen so far has been related to a single run or experiment.\n", "\n", "Experiment management starts to shine when you can organize, filter, and group many experiments at once.\n", "\n", "We organize our runs into \"projects\" and view them on the W&B \"project page\" \n", "([docs](https://docs.wandb.ai/ref/app/pages/project-page)).\n", "\n", "By default in the Lightning integration, the project name is determined based on directory information.\n", "This default can be overridden in the code when creating a `WandbLogger`,\n", "but we find it easier to change it from the command line by setting the `WANDB_PROJECT` environment variable."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see what the project page looks like for a longer-running project with lots of experiments.\n", "\n", "The cell below pulls up the project page for some of the debugging and feature addition work done while updating the course from 2021 to 2022." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "project_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/workspace\"\n", "\n", "print(project_url)\n", "IFrame(src=project_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This page and these charts have been customized -- filtering down to the most interesting training runs and surfacing the most important high-level information about them.\n", "\n", "We welcome you to poke around in this interface: deactivate or change the filters, click through into individual runs, and change the charts around." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Artifacts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Beyond logging metrics and metadata from runs,\n", "we can also log and version large binary files, or artifacts, and their metadata ([docs](https://docs.wandb.ai/guides/artifacts/artifacts-core-concepts))." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below pulls up all of the artifacts associated with the experiment we just ran."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "IFrame(src=last_expt.url + \"/artifacts\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Click on one of the `model` checkpoints -- the specific version doesn't matter.\n", "\n", "There are a number of tabs here.\n", "\n", "The \"Overview\" tab includes automatically generated metadata, like which run by which user created this model checkpoint, when, and how much disk space it takes up.\n", "\n", "The \"Metadata\" tab includes configurable metadata, here hyperparameters and metrics like `validation/cer`,\n", "which are added by default by the `WandbLogger`.\n", "\n", "The \"Files\" tab contains the actual file contents of the artifact.\n", "\n", "On the left-hand side of the page, you'll see the other versions of the model checkpoint,\n", "including some versions that are \"tagged\" with version aliases, like `latest` or `best`.\n", "\n", "You can click on these to explore the different versions and even directly compare them.\n", "\n", "If you're particularly interested in this tool, try comparing two versions of the `validation-predictions` artifact, starting from the Files tab and clicking inside it to `validation/predictions.table.json`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Artifact storage is part of the W&B free tier.\n", "\n", "The storage limits, as of August 2022, cover 100GB of Artifacts and experiment data.\n", "\n", "The former is sufficient to store ~700 model checkpoints for the Text Recognizer." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can track your data storage and compare it to your limits at this URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "storage_tracker_url = f\"https://wandb.ai/usage/{last_expt.entity}\"\n", "\n", "print(storage_tracker_url)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Programmatic Access" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also programmatically access our data and metadata via the `wandb` API\n", "([docs](https://docs.wandb.ai/guides/track/public-api-guide)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wb_api = wandb.Api()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, we can access the metrics we just logged as a `pandas.DataFrame` by grabbing the run via the API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run = wb_api.run(\"/\".join( # fetch a run given\n", " [last_expt.entity, # the user or org it was logged to\n", " last_expt.project, # the \"project\", usually one of several per repo/application\n", " last_expt.id] # and a unique ID\n", "))\n", "\n", "hist = run.history() # and pull down a sample of the data as a pandas DataFrame\n", "\n", "hist.head(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hist.groupby(\"epoch\")[\"train/loss\"].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this includes the artifacts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# which artifacts were created and logged?\n", "artifacts = run.logged_artifacts()\n", "\n", "for artifact in artifacts:\n", " print(f\"artifact of type {artifact.type}: {artifact.name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thanks to our 
`ImageToTextTableLogger`,\n", "we can easily recreate training or validation data that came out of our `DataLoader`s,\n", "which is normally ephemeral:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "artifact = wb_api.artifact(f\"{last_expt.entity}/{last_expt.project}/run-{last_expt.id}-trainpredictions:latest\")\n", "artifact_dir = Path(artifact.download(root=\"training/logs\"))\n", "image_dir = artifact_dir / \"media\" / \"images\"\n", "\n", "images = [path for path in image_dir.iterdir()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "from IPython.display import Image\n", "\n", "Image(str(random.choice(images)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Advanced W&B API Usage: MLOps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the strengths of a well-instrumented experiment tracking system is that it allows\n", "automatic relation of information:\n", "what were the inputs when this model's gradient spiked?\n", "Which models have been trained on this dataset,\n", "and what was their performance?\n", "\n", "Having access and automation around this information is necessary for \"MLOps\",\n", "which applies contemporary DevOps principles to ML projects." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cells below pull down the training data\n", "for the model currently running the FSDL Text Recognizer app.\n", "\n", "This is just intended as a demonstration of what's possible,\n", "so don't worry about understanding every piece of this,\n", "and feel free to skip past it.\n", "\n", "MLOps is still a nascent field, and these tools and workflows are likely to change.\n", "\n", "For example, just before the course launched, W&B released a\n", "[Model Registry layer](https://docs.wandb.ai/guides/models)\n", "on top of artifact logging that aims to improve the developer experience for these workflows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We start from the same project we looked at in the project view:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_recognizer_project = wb_api.project(\"fsdl-text-recognizer-2021-training\", entity=\"cfrye59\")\n", "\n", "text_recognizer_project " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we search it for the text recognizer model currently being used in production:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# collect all versions of the text-recognizer ever put into production by...\n", "\n", "for art_type in text_recognizer_project.artifacts_types(): # looking through all artifact types\n", " if art_type.name == \"prod-ready\": # for the prod-ready type\n", " # and grabbing the text-recognizer\n", " production_text_recognizers = art_type.collection(\"paragraph-text-recognizer\").versions()\n", "\n", "# and then get the one that's currently being tested in CI by...\n", "for text_recognizer in production_text_recognizers:\n", " if \"ci-test\" in text_recognizer.aliases: # looking for the one that's labeled as CI-tested\n", " in_prod_text_recognizer = text_recognizer\n", "\n", "# view its metadata at the url or 
in the notebook\n", "in_prod_text_recognizer_url = text_recognizer_project.url[:-9] + f\"artifacts/{in_prod_text_recognizer.type}/{in_prod_text_recognizer.name.replace(':', '/')}\"\n", "\n", "print(in_prod_text_recognizer_url)\n", "IFrame(src=in_prod_text_recognizer_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From its metadata, we can get information about how it was \"staged\" to be put into production,\n", "and in particular which model checkpoint was used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "staging_run = in_prod_text_recognizer.logged_by()\n", "\n", "training_ckpt, = [at for at in staging_run.used_artifacts() if at.type == \"model\"]\n", "training_ckpt.name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That checkpoint was logged by a training experiment, which is available as metadata.\n", "\n", "We can look at the training run for that model, either here in the notebook or at its URL:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "training_run = training_ckpt.logged_by()\n", "print(training_run.url)\n", "IFrame(src=training_run.url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And from there, we can access logs and metadata about training,\n", "confident that we are working with the model that is actually in production.\n", "\n", "For example, we can pull down the data we logged and analyze it locally." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training_results = training_run.history(samples=10000)\n", "training_results.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ax = training_results.groupby(\"epoch\")[\"train/loss\"].mean().plot();\n", "training_results[\"validation/loss\"].dropna().plot(logy=True); ax.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 10\n", "training_results[\"validation/loss\"].dropna().iloc[idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reports" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The charts and webpages in Weights & Biases\n", "are substantially more useful than ephemeral stdouts or raw logs on disk.\n", "\n", "If you're spun up on the project,\n", "they accelerate debugging, exploration, and discovery.\n", "\n", "If not, they're not so much useful as they are overwhelming.\n", "\n", "We need to synthesize the raw logged data into information.\n", "This helps us communicate our work with other stakeholders,\n", "preserve knowledge and prevent repetition of work,\n", "and surface insights faster.\n", "\n", "These workflows are supported by the W&B Reports feature\n", "([docs here](https://docs.wandb.ai/guides/reports)),\n", "which mixes W&B charts and tables with explanatory markdown text and embeds.\n", "\n", "Below are some common report patterns and\n", "use cases and examples of each." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the examples are from the FSDL Text Recognizer project.\n", "You can find more of them\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/-Report-of-Reports---VmlldzoyMjEwNDM5),\n", "where we've organized them into a report!"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dashboard Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dashboards are a structured subset of the output from one or more experiments,\n", "designed for quickly surfacing issues or insights,\n", "like an accuracy or performance regression\n", "or a change in the data distribution.\n", "\n", "Use cases:\n", "- show the basic state of an ongoing experiment\n", "- compare one experiment to another\n", "- select the most important charts so you can spin back up into context on a project more quickly" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n", "\n", "IFrame(src=dashboard_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pull Request Documentation Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In most software codebases,\n", "pull requests are a key focal point\n", "for units of work that combine\n", "short-term communication and long-term information tracking.\n", "\n", "In ML codebases, it's more difficult to bring\n", "sufficient information together to make PRs as useful.\n", "At FSDL, we like to add documentary\n", "reports with one or a small number of charts\n", "that connect logged information in the experiment management system\n", "to state in the version control software.\n", "\n", "Use cases:\n", "- communication of results within a team, e.g. 
code review\n", "- record-keeping that links pull request pages to raw logged info and makes it discoverable\n", "- improving confidence in PR correctness" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bugfix_doc_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Overfit-Check-After-Refactor--VmlldzoyMDY5MjI1\"\n", "\n", "IFrame(src=bugfix_doc_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Blog Post Report" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With sufficient effort, the logged data in the experiment management system\n", "can be made clear enough to be consumed,\n", "sufficiently contextualized to be useful outside the team, and\n", "even beautiful.\n", "\n", "The result is a report that's closer to a blog post than a dashboard or internal document.\n", "\n", "Use cases:\n", "- communication between teams or vertically in large organizations\n", "- external technical communication for branding and recruiting\n", "- attracting users or contributors\n", "\n", "Check out this example, from the Craiyon.ai / DALL·E Mini project, by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma)\n", "and others:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dalle_mini_blog_url = \"https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA#training-dall-e-mini\"\n", "\n", "IFrame(src=dalle_mini_blog_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter Optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many of our choices, like the depth of our network, the nonlinearities of our layers,\n", "and the learning rate and other parameters of our optimizer, cannot be\n", "([easily](https://arxiv.org/abs/1606.04474))\n", "chosen by descent of 
the gradient of a loss function.\n", "\n", "But these parameters that impact the values of the parameters\n", "we directly optimize with gradients, or _hyperparameters_,\n", "can still be optimized,\n", "essentially by trying options and selecting the values that worked best.\n", "\n", "In general, you can attain much of the benefit of hyperparameter optimization with minimal effort.\n", "\n", "Expending more compute can squeeze small amounts of additional validation or test performance\n", "that makes for impressive results on leaderboards but typically doesn't translate\n", "into better user experience.\n", "\n", "In general, the FSDL recommendation is to use the hyperparameter optimization workflows\n", "built into your other tooling.\n", "\n", "Weights & Biases makes the most straightforward forms of hyperparameter optimization trivially easy\n", "([docs](https://docs.wandb.ai/guides/sweeps)).\n", "\n", "It also supports a number of more advanced tools, like\n", "[Hyperband](https://docs.wandb.ai/guides/sweeps/configuration#early_terminate)\n", "for early termination of poorly-performing runs.\n", "\n", "We can use the same training script and we don't need to run an optimization server.\n", "\n", "We just need to write a configuration yaml file\n", "([docs](https://docs.wandb.ai/guides/sweeps/configuration)),\n", "like the one below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%writefile training/simple-overfit-sweep.yaml\n", "# first we specify what we're sweeping\n", "# we specify a program to run\n", "program: training/run_experiment.py\n", "# we optionally specify how to run it, including setting default arguments\n", "command: \n", " - ${env}\n", " - ${interpreter}\n", " - ${program}\n", " - \"--wandb\"\n", " - \"--overfit_batches\"\n", " - \"1\"\n", " - \"--log_every_n_steps\"\n", " - \"25\"\n", " - \"--max_epochs\"\n", " - \"100\"\n", " - \"--limit_test_batches\"\n", " - \"0\"\n", " - ${args} # these arguments come from the sweep parameters below\n", "\n", "# and we specify which parameters to sweep over, what we're optimizing, and how we want to optimize it\n", "method: random # generally, random searches perform well, can also be \"grid\" or \"bayes\"\n", "metric:\n", " name: train/loss\n", " goal: minimize\n", "parameters: \n", " # LineCNN hyperparameters\n", " window_width:\n", " values: [8, 16, 32, 64]\n", " window_stride:\n", " values: [4, 8, 16, 32]\n", " # Transformer hyperparameters\n", " tf_layers:\n", " values: [1, 2, 4, 8]\n", " # we can also fix some values, just like we set default arguments\n", " gpus:\n", " value: 1\n", " model_class:\n", " value: LineCNNTransformer\n", " data_class:\n", " value: IAMLines\n", " loss:\n", " value: transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the config we launch a \"controller\":\n", "a lightweight process that just decides what hyperparameters to try next\n", "and coordinates the heavierweight training.\n", "\n", "This lives on the W&B servers, so there are no headaches about opening ports for communication,\n", "cleaning up when it's done, etc." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wandb sweep training/simple-overfit-sweep.yaml --project fsdl-line-recognizer-2022\n", "simple_sweep_id = wb_api.project(\"fsdl-line-recognizer-2022\").sweeps()[0].id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then we can launch an \"agent\" to follow the orders of the controller:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "\n", "# interrupt twice to terminate this cell if it's running too long,\n", "# it can be over 15 minutes with some hyperparameters\n", "\n", "!wandb agent --project fsdl-line-recognizer-2022 --entity {wb_api.default_entity} --count=1 {simple_sweep_id}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above cell runs only a single experiment, because we provided the `--count` argument with a value of `1`.\n", "\n", "If not provided, the agent will run forever for random or Bayesian sweeps\n", "or until the sweep is terminated, which can be done from the W&B interface." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The agents make for a slick workflow for distributing sweeps across GPUs.\n", "\n", "We can just change the `CUDA_VISIBLE_DEVICES` environment variable,\n", "which controls which GPUs are accessible by a process, to launch\n", "parallel agents on separate GPUs on the same machine." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "CUDA_VISIBLE_DEVICES=0 wandb agent $SWEEP_ID\n", "# open another terminal\n", "CUDA_VISIBLE_DEVICES=1 wandb agent $SWEEP_ID\n", "# and so on\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include optional exercises with the labs for learners who want to dive deeper on specific topics." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟Contribute to a hyperparameter search." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've kicked off a big hyperparameter search on the `LineCNNTransformer` that anyone can join!\n", "\n", "There are ~10,000,000 potential hyperparameter combinations,\n", "and each takes 30 minutes to test,\n", "so checking each possibility will take over 500 years of compute time.\n", "Best get cracking then!\n", "\n", "Run the cell below to pull up a dashboard and print the URL where you can check on the current status." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_entity = \"fullstackdeeplearning\"\n", "sweep_project = \"fsdl-line-recognizer-2022\"\n", "sweep_id = \"e0eo43eu\"\n", "sweep_url = f\"https://wandb.ai/{sweep_entity}/{sweep_project}/sweeps/{sweep_id}\"\n", "\n", "print(sweep_url)\n", "IFrame(src=sweep_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also retrieve information about the sweep from the API,\n", "including the hyperparameters being swept over." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sweep_info = wb_api.sweep(\"/\".join([sweep_entity, sweep_project, sweep_id]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparams = sweep_info.config[\"parameters\"]\n", "hyperparams" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you'd like to contribute to this sweep,\n", "run the cell below after changing the count to a number greater than 0.\n", "\n", "Each iteration runs for 30 minutes if it does not crash,\n", "e.g. due to out-of-memory errors." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "count = 0 # off by default, increase it to join in!\n", "\n", "if count:\n", " !wandb agent {sweep_id} --entity {sweep_entity} --project {sweep_project} --count {count}" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Write some manual logging in `wandb`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the FSDL Text Recognizer codebase,\n", "we almost exclusively log to W&B through Lightning,\n", "rather than through the `wandb` Python SDK.\n", "\n", "If you're interested in learning how to use W&B directly, e.g. with another training framework,\n", "try out this quick exercise that introduces the key players in the SDK." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The cell below starts a run with `wandb.init` and provides configuration hyperparameters with `wandb.config`.\n", "\n", "It also calculates a `loss` value and saves a text file, `logs/hello.txt`.\n", "\n", "Add W&B metric and artifact logging to this cell:\n", "- use [`wandb.log`](https://docs.wandb.ai/guides/track/log) to log the loss on each step\n", "- use [`wandb.log_artifact`](https://docs.wandb.ai/guides/artifacts) to save `logs/hello.txt` in an artifact with the name `hello` and whatever type you wish" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import random\n", "\n", "import wandb\n", "\n", "\n", "os.makedirs(\"logs\", exist_ok=True)\n", "\n", "project = \"trying-wandb\"\n", "config = {\"steps\": 50}\n", "\n", "\n", "with wandb.init(project=project, config=config) as run:\n", " steps = wandb.config[\"steps\"]\n", " \n", " for ii in range(steps):\n", " loss = math.exp(-ii) + random.random() / (ii + 1) # ML means making the loss go down\n", " \n", " with open(\"logs/hello.txt\", \"w\") as f:\n", " f.write(\"hello from wandb, 
my dudes!\")\n", " \n", " run_id = run.id" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you've correctly completed the exercise, the cell below will print only 🥞 emojis and no 🥲s before opening the run in an iframe." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hello_run = wb_api.run(f\"{project}/{run_id}\")\n", "\n", "# check for logged loss data\n", "if \"loss\" not in hello_run.history().keys():\n", " print(\"loss not logged 🥲\")\n", "else:\n", " print(\"loss logged successfully 🥞\")\n", " if len(hello_run.history()[\"loss\"]) != steps:\n", " print(\"loss not logged on all steps 🥲\")\n", " else:\n", " print(\"loss logged on all steps 🥞\")\n", "\n", "artifacts = hello_run.logged_artifacts()\n", "\n", "# check for artifact with the right name\n", "if \"hello:v0\" not in [artifact.name for artifact in artifacts]:\n", " print(\"hello artifact not logged 🥲\")\n", "else:\n", " print(\"hello artifact logged successfully 🥞\")\n", " # check for the file inside the artifacts\n", " if \"hello.txt\" not in sum([list(artifact.manifest.entries.keys()) for artifact in artifacts], []):\n", " print(\"could not find hello.txt 🥲\")\n", " else:\n", " print(\"hello.txt logged successfully 🥞\")\n", " \n", " \n", "hello_run" ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Find good hyperparameters for the `LineCNNTransformer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default hyperparameters for the `LineCNNTransformer` are not particularly carefully tuned." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Try and find some better hyperparameters: choices that achieve a lower loss on the full dataset faster." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you observe interesting phenomena during training,\n", "from promising hyperparameter combos to software bugs to strange model behavior,\n", "turn the charts into a W&B report and share it with the FSDL community or\n", "[open an issue on GitHub](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/issues)\n", "with a link to them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# check the sweep_info.config above to see the model and data hyperparameters\n", "# read through the --help output for all potential arguments\n", "%run training/run_experiment.py --model_class LineCNNTransformer --data_class IAMLines \\\n", " --loss transformer --batch_size 32 --gpus {gpus} --max_epochs 5 \\\n", " --log_every_n_steps 50 --wandb --limit_test_batches 0.1 \\\n", " --limit_train_batches 0.1 --limit_val_batches 0.1 \\\n", " --help # remove this line to run an experiment instead of printing help\n", " \n", "last_hyperparam_expt = wandb.run # in case you want to pull URLs, look up in API, etc., as in code above\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🌟🌟🌟 Add logging of tensor statistics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to logging model inputs and outputs as human-interpretable media,\n", "it's also frequently useful to see information about their numerical values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're interested in learning more about metric calculation and logging with Lightning,\n", "use [`torchmetrics`](https://torchmetrics.readthedocs.io/en/v0.7.3/)\n", "to add tensor statistic logging to the `LineCNNTransformer`.\n", "\n", "`torchmetrics` comes with built in statistical metrics, like `MinMetric`, `MaxMetric`, and `MeanMetric`.\n", "\n", "All three are useful, but start by adding just one." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use your metric with `training/run_experiment.py`, you'll need to open and edit the `text_recognizer/lit_models/base.py` and `text_recognizer/lit_models/transformer.py` files:\n", "- Add the metrics to the `BaseImageToTextLitModel`'s `__init__` method, around where `CharacterErrorRate` appears.\n", " - You'll also need to decide whether to calculate separate train/validation/test versions. Whatever you do, start by implementing just one.\n", "- In the appropriate `_step` methods of the `TransformerLitModel`, add metric calculation and logging for `Min`, `Max`, and/or `Mean`.\n", " - Base your code on the calculation and logging of the `val_cer` metric.\n", " - `sync_dist=True` is only important in distributed training settings, so you might not notice any issues regardless of that argument's value." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For an extra challenge, use `MeanSquaredError` to implement a `VarianceMetric`. _Hint_: one way is to use `torch.zeros_like` and `torch.mean`."
] } ], "metadata": { "accelerator": "GPU", "colab": { "authorship_tag": "ABX9TyMKpeodqRUzgu0VjkCVMBeJ", "collapsed_sections": [], "name": "lab04_experiments.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab08/notebooks/lab05_troubleshooting.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 05: Troubleshooting & Testing" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- Practices and tools for testing and linting Python code in general: `black`, `flake8`, `pre-commit`, `pytest`, and `doctest`\n", "- How to implement tests for ML training systems in particular\n", "- What a PyTorch training step looks like under the hood and how to troubleshoot performance bottlenecks" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 5\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # allow \"hot-reloading\" of modules\n", " %load_ext autoreload\n", " %autoreload 2\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sThWeTtV6fL_" }, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IFrame(src=\"https://fsdl.me/2022-lab-05-video-embed\", width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "xFP8lU4nSg1P" }, "source": [ "# Linting Python and Shell Scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "cXbdYfFlPhZ-" }, "source": [ "### Automatically linting with `pre-commit`" ] }, { "cell_type": "markdown", "metadata": { "id": "ysqqb2GjvLrz" }, "source": [ "We 
want to keep our code clean and uniform across developers\n", "and time.\n", "\n", "Applying the cleanliness checks and style rules should be\n", "as painless and automatic as possible.\n", "\n", "For this purpose, we recommend bundling linting tools together\n", "and enforcing them on all commits with\n", "[`pre-commit`](https://pre-commit.com/)." ] }, { "cell_type": "markdown", "metadata": { "id": "XvqtZChKvLr0" }, "source": [ "In addition to running on every commit,\n", "`pre-commit` separates the model development environment from the environments\n", "needed for the linting tools, preventing conflicts\n", "and simplifying maintenance and onboarding." ] }, { "cell_type": "markdown", "metadata": { "id": "Y0XuIuKOXhJl" }, "source": [ "This cell runs `pre-commit`.\n", "\n", "The first time it is run on a machine, it will install the environments for all tools." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hltYGbpNvLr1" }, "outputs": [], "source": [ "!pre-commit run --all-files" ] }, { "cell_type": "markdown", "metadata": { "id": "gLw08gIkvLr1" }, "source": [ "The output lists all the checks that are run and whether they are passed.\n", "\n", "Notice there are a number of simple version-control hygiene practices included\n", "that aren't even specific to Python, much less to machine learning.\n", "\n", "For example, several of the checks prevent accidental commits with private keys, large files, \n", "leftover debugger statements, or merge conflict annotations in them."
] }, { "cell_type": "markdown", "metadata": { "id": "RHEEjb9kvLr1" }, "source": [ "These linting actions are configured via\n", "([what else?](https://twitter.com/charles_irl/status/1446235836794564615?s=20&t=OOK-9NbgbJAoBrL8MkUmuA))\n", "a YAML file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dgXa8BzrvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml" ] }, { "cell_type": "markdown", "metadata": { "id": "8HYc_WbTvLr2" }, "source": [ "Most of the general cleanliness checks are from hooks built by `pre-commit`.\n", "\n", "See the comments and links in the `.pre-commit-config.yaml` for more:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K9rTgRqzvLr2" }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep repos -A 15" ] }, { "cell_type": "markdown", "metadata": { "id": "1ptkO7aPvLr2" }, "source": [ "Let's take a look at the section of the file\n", "that applies most of our Python style enforcement with\n", "[`flake8`](https://flake8.pycqa.org/en/latest/):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ALsRKfcevLr3", "scrolled": true }, "outputs": [], "source": [ "!cat .pre-commit-config.yaml | grep \"flake8 python\" -A 10" ] }, { "cell_type": "markdown", "metadata": { "id": "a_Q0BwQUXbg6" }, "source": [ "The majority of the style checking behavior we want comes from the\n", "`additional_dependencies`, which are\n", "[plugins](https://flake8.pycqa.org/en/latest/glossary.html#term-plugin)\n", "that extend `flake8`'s list of lints.\n", "\n", "Notice that we have a `--config` file passed in to the `args` for the `flake8` command.\n", "\n", "We keep the configuration information for `flake8`\n", "separate from that for `pre-commit`\n", "in case we want to use additional tools with `flake8`,\n", "e.g. 
if some developers want to integrate it directly into their editor,\n", "and so that if we change away from `pre-commit`\n", "but keep `flake8` we don't have to\n", "recreate our configuration in a different tool.\n", "\n", "As much as possible, codebases should strive for single sources of truth\n", "and link back to those sources of truth with documentation or comments,\n", "as in the last line above.\n", "\n", "Let's take a look at the contents of `.flake8`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "doC_4WQwvLr3" }, "outputs": [], "source": [ "!cat .flake8" ] }, { "cell_type": "markdown", "metadata": { "id": "0Nq6HnyU0M47" }, "source": [ "There's a lot here! We'll focus on the most important bits." ] }, { "cell_type": "markdown", "metadata": { "id": "U4PiB8CPvLr3" }, "source": [ "Linting tools in Python generally work by emitting error codes\n", "with one or more letters followed by three numbers.\n", "The `select` argument picks which error codes we want to check for.\n", "Error codes are matched by prefix,\n", "so for example `B` matches `BTS101` and\n", "`G1` matches `G102` and `G199` but not `ARG404`.\n", "\n", "Certain codes are `ignore`d in the default `flake8` style,\n", "which is done via the `ignore` argument,\n", "and we can `extend` the list of `ignore`d codes with `extend-ignore`.\n", "For example, we rely on `black` to do our formatting,\n", "so we ignore some of `flake8`'s formatting codes.\n", "\n", "Together, these settings define our project's particular style.\n", "\n", "But not every file fits this style perfectly.\n", "Most of the conventions in `black` and `flake8` come from the style-defining\n", "[Python Enhancement Proposal 8](https://peps.python.org/pep-0008/),\n", "which exhorts you to \"know when to be inconsistent\".\n", "\n", "To allow ourselves to be inconsistent when we know we should be,\n", "`flake8` includes `per-file-ignores`,\n", "which let us ignore specific warnings in specific files.\n",
"This is one of the \"escape valves\"\n", "that makes style enforcement tolerable.\n", "We can also `exclude` files in the `pre-commit` config itself.\n", "\n", "For details on selecting and ignoring,\n", "see the [`flake8` docs](https://flake8.pycqa.org/en/latest/user/violations.html)\n", "\n", "For definitions of the error codes from `flake8` itself,\n", "see the [list in the docs](https://flake8.pycqa.org/en/latest/user/error-codes.html).\n", "Individual extensions list their added error codes in their documentation,\n", "e.g. `darglint` does so\n", "[here](https://github.com/terrencepreilly/darglint#error-codes)." ] }, { "cell_type": "markdown", "metadata": { "id": "NL0TpyPsvLr4" }, "source": [ "The remainder are configurations for the other `flake8` plugins that we use to define and enforce the rest of our style.\n", "\n", "You can read more about each in their documentation:\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "markdown", "metadata": { "id": "mFsZC0a7vLr4" }, "source": [ "### Linting via a script and using `shellcheck`" ] }, { "cell_type": "markdown", "metadata": { "id": "RYjpuFwjXkJc" }, "source": [ "To avoid needing to think about `pre-commit`\n", "(was the command `pre-commit run` or `pre-commit check`?)\n", "while developing locally,\n", "we might put our linters into a shell script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mXlLFWmavLr4" }, "outputs": [], "source": [ "!cat tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "PPxHpRIB3nbw" }, "source": [ "These kinds of short and simple shell scripts are common in projects\n", "of intermediate 
size.\n", "\n", "They are useful for adding automation and reducing friction." ] }, { "cell_type": "markdown", "metadata": { "id": "TMuPBpAi2qwl" }, "source": [ "But these scripts are code,\n", "and all code is susceptible to bugs and subject to concerns of style consistency." ] }, { "cell_type": "markdown", "metadata": { "id": "SQRg3ZqXvLr4" }, "source": [ "We can't check these scripts with tools that lint Python code,\n", "so we include a shell script linting tool,\n", "[`shellcheck`](https://www.shellcheck.net/),\n", "in our `pre-commit`.\n", "\n", "More so than checking for correct style,\n", "this tool checks for common bugs or surprising behaviors of shells,\n", "which are unfortunately numerous." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zkfhE1srvLr4" }, "outputs": [], "source": [ "script_filename = \"tasks/lint.sh\"\n", "!pre-commit run shellcheck --files {script_filename}" ] }, { "cell_type": "markdown", "metadata": { "id": "KXU9TRrwvLr4" }, "source": [ "That script has already been tested, so we don't see any errors.\n", "\n", "Try copying over a script you've written yourself or\n", "even from a popular repo that you like\n", "(by adding to the notebook directory or by making a cell\n", "with `%%writefile` at the top)\n", "and test it by changing the `script_filename`.\n", "\n", "You'd be surprised at the classes of subtle bugs possible in bash!" 
] }, { "cell_type": "markdown", "metadata": { "id": "81MhAL-TvLr5" }, "source": [ "### Try \"unofficial bash strict mode\" for louder failures in scripts" ] }, { "cell_type": "markdown", "metadata": { "id": "hSwhs_zUvLr5" }, "source": [ "Another way to reduce bugs is to use the suggested \"unofficial bash strict mode\" settings by\n", "[@redsymbol](https://twitter.com/redsymbol),\n", "which appear at the top of the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o-j0vSxEvLr5" }, "outputs": [], "source": [ "!head -n 3 tasks/lint.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "d2iJU5jlvLr5" }, "source": [ "The core idea of strict mode is to fail more loudly.\n", "This is a desirable behavior of scripts,\n", "like the ones we're writing,\n", "even though it's an undesirable behavior for an interactive shell --\n", "it would be unpleasant to be logged out every time you hit an error.\n", "\n", "`set -u` means scripts fail if a variable's value is `u`nset,\n", "i.e. not defined.\n", "Otherwise bash is perfectly happy to allow you to reference undefined variables.\n", "The result is just an empty string, which can lead to maddeningly weird behavior.\n", "\n", "`set -o pipefail` means failures inside a pipe of commands (`|`) propagate,\n", "rather than using the exit code of the last command.\n", "Unix tools are perfectly happy to work on nonsense input,\n", "like sorting error messages, instead of the filenames you meant to send.\n", "\n", "You can read more about these choices\n", "[here](http://redsymbol.net/articles/unofficial-bash-strict-mode/),\n", "and considerations for working with other non-conforming scripts in \"strict mode\"\n", "and for handling resource teardown when scripts error out." 
] }, { "cell_type": "markdown", "metadata": { "id": "s1XqsrU_XWWS" }, "source": [ "# Testing ML Codebases" ] }, { "cell_type": "markdown", "metadata": { "id": "CPNzeq3NYF2W" }, "source": [ "## Testing Python code with `pytest`" ] }, { "cell_type": "markdown", "metadata": { "id": "zq5e_x6gc9Vu" }, "source": [ "\n", "ML codebases are Python first and foremost, so first let's get some Python tests going." ] }, { "cell_type": "markdown", "metadata": { "id": "0DC3GxYz6_R9" }, "source": [ "At a basic level,\n", "we can write functions that `assert`\n", "that our code behaves as expected in\n", "a given scenario and include them in the same module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Rvd-GNwv63W1" }, "outputs": [], "source": [ "from text_recognizer.lit_models.metrics import test_character_error_rate\n", "\n", "test_character_error_rate??" ] }, { "cell_type": "markdown", "metadata": { "id": "iVB2TsQS5BTq" }, "source": [ "The standard tool for testing Python code is\n", "[`pytest`](https://docs.pytest.org/en/7.1.x/).\n", "\n", "We can use it as a command-line tool in a variety of ways,\n", "including to execute these kinds of tests.\n", "\n", "If passed a filename, `pytest` will look for\n", "any classes that start with `Test` or\n", "any functions that start with `test_` and run them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u8sQguyJvLr6", "scrolled": false }, "outputs": [], "source": [ "!pytest text_recognizer/lit_models/metrics.py" ] }, { "cell_type": "markdown", "metadata": { "id": "92tkBCllvLr6" }, "source": [ "After the results of the tests (pass or fail) are returned,\n", "you'll see a report of \"coverage\" from\n", "[`codecov`](https://about.codecov.io/).\n", "\n", "This coverage report tells us which files and how many lines in those files\n", "were touched by the testing suite."
] }, { "cell_type": "markdown", "metadata": { "id": "PllSUe0s5xvU" }, "source": [ "We do not actually need to provide the names of files with tests in them to `pytest`\n", "in order for it to run our tests." ] }, { "cell_type": "markdown", "metadata": { "id": "4qOBHJnTZM9x" }, "source": [ "By default, `pytest` looks for any files named `test_*.py` or `*_test.py`.\n", "\n", "It's [good practice](https://docs.pytest.org/en/7.1.x/explanation/goodpractices.html#test-discovery)\n", "to separate these from the rest of your code\n", "in a folder or folders named `tests`,\n", "rather than scattering them around the repo." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "acjsYTNSvLr6" }, "outputs": [], "source": [ "!ls text_recognizer/tests" ] }, { "cell_type": "markdown", "metadata": { "id": "WZQQZUF0vLr6" }, "source": [ "Let's take a look at a specific example:\n", "the tests for some of our utilities around\n", "custom PyTorch Lightning `Callback`s." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oS0xKv1evLr6" }, "outputs": [], "source": [ "from text_recognizer.tests import test_callback_utils\n", "\n", "\n", "test_callback_utils.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "lko8msn-vLr7" }, "source": [ "Notice that we can easily import this as a module!\n", "\n", "That's another benefit of organizing tests into specialized files." ] }, { "cell_type": "markdown", "metadata": { "id": "5A85FUNv75Fr" }, "source": [ "The particular utility we're testing\n", "here is designed to prevent crashes:\n", "it checks for a particular type of error and turns it into a warning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Jl4-DiVe76sw" }, "outputs": [], "source": [ "from text_recognizer.callbacks.util import check_and_warn\n", "\n", "check_and_warn??" 
] }, { "cell_type": "markdown", "metadata": { "id": "B6E0MhduvLr7" }, "source": [ "Error-handling code is a common cause of bugs,\n", "a fact discovered\n", "[again and again across forty years of error analysis](https://twitter.com/full_stack_dl/status/1561880960886505473?s=20&t=5OZBonILaUJE9J4ah2Qn0Q),\n", "so it's very important to test it well!\n", "\n", "We start with a very basic test,\n", "which does not touch anything\n", "outside of the Python standard library,\n", "even though this tool is intended to be used\n", "with more complex features of third-party libraries,\n", "like `wandb` and `tensorboard`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xx5koQmJvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_simple??" ] }, { "cell_type": "markdown", "metadata": { "id": "MZe9-JVjvLr7" }, "source": [ "Here, we are just testing the core logic.\n", "This test won't catch many bugs,\n", "but when it does fail, something has gone seriously wrong.\n", "\n", "These kinds of tests are important for resolving a bug:\n", "we learn nearly as much from the tests that passed\n", "as we did from the tests that failed.\n", "If this test has failed, possibly along with others,\n", "we can rule out an issue in one of the large external codebases\n", "touched in the other tests, saving us lots of time in our troubleshooting.\n", "\n", "The reasoning for the test is explained in the docstrings, \n", "which are close to the code.\n", "\n", "Your test suite should be as welcoming\n", "as the rest of your codebase!\n", "The people reading it, for example yourself in six months, \n", "are likely upset and in need of some kindness.\n", "\n", "More practically, we want to keep our time to resolve errors as short as possible,\n", "and five minutes to write a good docstring now\n", "can save five minutes during an outage, when minutes really matter."
] }, { "cell_type": "markdown", "metadata": { "id": "Om9k-uXhvLr7" }, "source": [ "That basic test is a start, but it's not enough by itself.\n", "There's a specific error case that triggered the addition of this code.\n", "\n", "So we test that it's handled as expected." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fjbsb5FvvLr7" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_tblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "CGAIZTUjvLr7" }, "source": [ "That test can fail if the libraries change around our code,\n", "i.e. if the `TensorBoardLogger` gets a `log_table` method.\n", "\n", "We want to be careful when making assumptions\n", "about other people's software,\n", "especially for fast-moving libraries like Lightning.\n", "If we test that those assumptions hold willy-nilly,\n", "we'll end up with tests that fail because of\n", "harmless changes in our dependencies.\n", "\n", "Tests that require a ton of maintenance and updating\n", "without leading to code improvements soak up\n", "more engineering time than they save\n", "and cause distrust in the testing suite.\n", "\n", "We include this test because `TensorBoardLogger` getting\n", "a `log_table` method will _also_ change the behavior of our code\n", "in a breaking way, and we want to catch that before it breaks\n", "a model training job." ] }, { "cell_type": "markdown", "metadata": { "id": "jsy95KAvvLr7" }, "source": [ "Adding error handling can also accidentally kill the \"happy path\"\n", "by raising an error incorrectly.\n", "\n", "So we explicitly test the _absence of an error_,\n", "not just its presence:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LRlIOkjmvLr8" }, "outputs": [], "source": [ "test_callback_utils.test_check_and_warn_wandblogger??" ] }, { "cell_type": "markdown", "metadata": { "id": "osiqpLynvLr8" }, "source": [ "There are more tests we could build, e.g. 
manipulating classes and testing the behavior,\n", "testing more classes that might be targeted by `check_and_warn`, or\n", "asserting that warnings are raised to the command line.\n", "\n", "But these three basic tests are likely to catch most changes that would break our code here,\n", "and they're a lot easier to write than the others.\n", "\n", "If this utility starts to get more usage and become a critical path for lots of features, we can always add more!" ] }, { "cell_type": "markdown", "metadata": { "id": "dm285JE5vLr8" }, "source": [ "## Interleaving testing and documentation with `doctests`" ] }, { "cell_type": "markdown", "metadata": { "id": "UHWQvgA8vLr8" }, "source": [ "One function of tests is to build user/reader confidence in code." ] }, { "cell_type": "markdown", "metadata": { "id": "wrhiJBXFvLr8" }, "source": [ "One function of documentation is to build user/reader knowledge in code." ] }, { "cell_type": "markdown", "metadata": { "id": "1vu12LDhvLr8" }, "source": [ "These functions are related. Let's put them together:\n", "put code in a docstring and test that code.\n", "\n", "This feature is part of the\n", "Python standard library via the\n", "[`doctest` module](https://docs.python.org/3/library/doctest.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "rmfIOwXd-Qt7" }, "source": [ "Here's an example from our `torch` utilities.\n", "\n", "The `first_appearance` function can be used to\n", "e.g. quickly look for stop tokens,\n", "giving the length of each sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZzURGcD9vLr8" }, "outputs": [], "source": [ "from text_recognizer.lit_models.util import first_appearance\n", "\n", "\n", "first_appearance??" 
] }, { "cell_type": "markdown", "metadata": { "id": "0VtYcJ1WvLr8" }, "source": [ "Notice that in the \"Examples\" section,\n", "there's a short block of code formatted as a\n", "Python interpreter session,\n", "complete with outputs.\n", "\n", "We can copy and paste that code and\n", "check that we get the right outputs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dj4lNOxJvLr9" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y9AWHFoIvLr9" }, "source": [ "We can run the test with `pytest` by passing a command line argument,\n", "`--doctest-modules`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JMaAxv5ovLr9" }, "outputs": [], "source": [ "!pytest --doctest-modules text_recognizer/lit_models/util.py" ] }, { "cell_type": "markdown", "metadata": { "id": "6-2_aOUfvLr9" }, "source": [ "With the\n", "[right configuration](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/627dc9dabc9070cb14bfe5bfcb1d6131eb7dc7a8/pyproject.toml#L12-L17),\n", "running `doctest`s happens automatically\n", "when `pytest` is invoked." ] }, { "cell_type": "markdown", "metadata": { "id": "my_keokPvLr9" }, "source": [ "## Basic tests for data code" ] }, { "cell_type": "markdown", "metadata": { "id": "Qj3Bq_j2_A8o" }, "source": [ "ML code can be hard to test\n", "since it involves very heavy artifacts, like models and data,\n", "and very expensive jobs, like training." 
] }, { "cell_type": "markdown", "metadata": { "id": "DT5OmgrQvLr9" }, "source": [ "For testing our data-handling code in the FSDL codebase,\n", "we mostly just use `assert`s,\n", "which throw errors when behavior differs from expectation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bdzn5g4TvLr9" }, "outputs": [], "source": [ "!grep \"assert\" -r text_recognizer/data" ] }, { "cell_type": "markdown", "metadata": { "id": "2aTlfu4_vLr-" }, "source": [ "This isn't great practice,\n", "especially as a codebase grows,\n", "because we can't easily know when these are executed\n", "or incorporate them into\n", "testing automation and coverage analysis tools." ] }, { "cell_type": "markdown", "metadata": { "id": "IaMTdmbZ_mkW" }, "source": [ "So it's preferable to collect up these assertions of simple data properties\n", "into tests that are run like our other tests.\n", "\n", "The test below checks whether any data is leaking\n", "between training, validation, and testing." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qx7cxiDdvLr-" }, "outputs": [], "source": [ "from text_recognizer.tests.test_iam import test_iam_data_splits\n", "\n", "\n", "test_iam_data_splits??" 
] }, { "cell_type": "markdown", "metadata": { "id": "16TJwhd1vLr-" }, "source": [ "Notice that we were able to load the test into the notebook\n", "because it is in a module,\n", "and so we can run it here as well:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mArITFkYvLr-" }, "outputs": [], "source": [ "test_iam_data_splits()" ] }, { "cell_type": "markdown", "metadata": { "id": "E4F2uaclvLr-" }, "source": [ "But we're checking something pretty simple here,\n", "so the new code in each test is just a single line.\n", "\n", "What if we wanted to test more complex properties,\n", "like comparing rows or calculating statistics?\n", "\n", "We'll end up writing more complex code that might itself have subtle bugs,\n", "requiring tests for our tests and suffering from\n", "\"tester's regress\".\n", "\n", "This is the phenomenon,\n", "named by analogy with\n", "[experimenter's regress](https://en.wikipedia.org/wiki/Experimenter%27s_regress)\n", "in sociology of science,\n", "where the validity of our tests is itself\n", "up for dispute only resolvable by testing the tests,\n", "but those tests are themselves possibly invalid." ] }, { "cell_type": "markdown", "metadata": { "id": "nUGT06gdvLr-" }, "source": [ "We cut this Gordian knot by using\n", "a library or framework that is well-tested.\n", "\n", "We recommend checking out\n", "[`great_expectations`](https://docs.greatexpectations.io/docs/)\n", "if you're looking for a high-quality data testing tool." ] }, { "cell_type": "markdown", "metadata": { "id": "dQ5vNsq3vLr-" }, "source": [ "Especially with data, some tests are particularly \"heavy\" --\n", "they take a long time,\n", "and we might want to run them\n", "on different machines\n", "and on a different schedule\n", "than our other tests." 
] }, { "cell_type": "markdown", "metadata": { "id": "xephcb0LvLr-" }, "source": [ "For example, consider testing whether the download of a dataset succeeds and gives the right checksum.\n", "\n", "We can't just use a cached version of the data,\n", "since that won't actually execute the code!\n", "\n", "This test will take\n", "as long to run\n", "and consume as many resources as\n", "a full download of the data." ] }, { "cell_type": "markdown", "metadata": { "id": "YSN4w2EqvLr-" }, "source": [ "`pytest` allows the separation of tests\n", "into suites with `mark`s,\n", "which \"tag\" tests with names." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V0rScrcXvLr_", "scrolled": false }, "outputs": [], "source": [ "!pytest --markers | head -n 10" ] }, { "cell_type": "markdown", "metadata": { "id": "lr5Ca7B0vLr_" }, "source": [ "We can choose to run tests with a given mark\n", "or to skip tests with a given mark, \n", "among other basic logical operations around combining and filtering marks,\n", "with `-m`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xmw-Eb1ZvLr_" }, "outputs": [], "source": [ "!wandb login # one test requires wandb authentication\n", "\n", "!pytest -m \"not data and not slow\"" ] }, { "cell_type": "markdown", "metadata": { "id": "5LuERxOXX_UJ" }, "source": [ "## Testing training with memorization tests" ] }, { "cell_type": "markdown", "metadata": { "id": "AnWLN4lRvLsA" }, "source": [ "Training is the process by which we convert inert data into executable models,\n", "so it is dependent on both.\n", "\n", "We decouple checking whether the script has a critical bug\n", "from whether the data or model code is broken\n", "by testing on some basic \"fake data\",\n", "based on a utility from `torchvision`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k4NIc3uWvLsA" }, "outputs": [], "source": [ "from text_recognizer.data import FakeImageData\n", "\n", "\n", "FakeImageData.__doc__" ] }, { "cell_type": "markdown", "metadata": { "id": "deN0swwlvLsA" }, "source": [ "We then test on the actual data with a smaller version of the real model.\n", "\n", "We use the Lightning `--fast_dev_run` feature,\n", "which sets the number of training, validation, and test batches to `1`.\n", "\n", "We use a smaller version so that this test can run in just a few minutes\n", "on a CPU without acceleration.\n", "\n", "That allows us to run our tests in environments without GPUs,\n", "which saves on costs for executing tests.\n", "\n", "Here's the script:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z4J0_uD9vLsA" }, "outputs": [], "source": [ "!cat training/tests/test_run_experiment.sh" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y-7u9zS1vLsA", "scrolled": false }, "outputs": [], "source": [ "! ./training/tests/test_run_experiment.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "UTzfo11KClV3" }, "source": [ "The above tests don't actually check\n", "whether any learning occurs;\n", "they just check\n", "whether training runs mechanically,\n", "without any errors.\n", "\n", "We also need a\n", "[\"smoke test\"](https://en.wikipedia.org/wiki/Smoke_testing_(software))\n", "for learning.\n", "For that we recommend checking whether\n", "the model can learn the right\n", "outputs for a single batch --\n", "to \"memorize\" the outputs for\n", "a particular input.\n", "\n", "This memorization test won't\n", "catch every bug or issue in training,\n", "which is notoriously difficult,\n", "but it will flag\n", "some of the most serious issues." ] }, { "cell_type": "markdown", "metadata": { "id": "0DVSp3aAvLsA" }, "source": [ "The script below runs a memorization test." 
] }, { "cell_type": "markdown", "metadata": { "id": "2DFVVrxpvLsA" }, "source": [ "It takes up to two arguments:\n", "a `MAX`imum number of `EPOCHS` to run for and\n", "a `CRITERION` value of the loss to test against.\n", "\n", "The test passes if the loss is lower than the `CRITERION` value\n", "after the `MAX`imum number of `EPOCHS` has passed." ] }, { "cell_type": "markdown", "metadata": { "id": "oEhJH0e5vLsB" }, "source": [ "The important line in this script is the one that invokes our training script,\n", "`training/run_experiment.py`.\n", "\n", "The arguments to `run_experiment` have been tuned for maximum possible speed:\n", "turning off regularization, shrinking the model,\n", "and skipping parts of Lightning that we don't want to test." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T-fFs1xEvLsB" }, "outputs": [], "source": [ "!cat training/tests/test_memorize_iam.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "X-47tUA_YNGe" }, "source": [ "If you'd like to see what a memorization run looks like,\n", "flip the `running_memorization` flag to `True`\n", "and watch the results stream in to W&B.\n", "\n", "The cell should run in about ten minutes on a commodity GPU." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GwTEsZwKvLsB" }, "outputs": [], "source": [ "%%time\n", "running_memorization = False\n", "\n", "if running_memorization:\n", "    max_epochs = 1000\n", "    loss_criterion = 0.05\n", "    !./training/tests/test_memorize_iam.sh {max_epochs} {loss_criterion}" ] }, { "cell_type": "markdown", "metadata": { "id": "zPoFCoEcC8SV" }, "source": [ "# Troubleshooting model speed with the PyTorch Profiler" ] }, { "cell_type": "markdown", "metadata": { "id": "DpbN-Om2Drf-" }, "source": [ "Testing code is only half the story here:\n", "we also need to fix the issues that our tests flag.\n", "This is the process of troubleshooting.\n", "\n", "In this lab,\n", "we'll focus on troubleshooting model performance issues:\n", "what to do when your model runs too slowly." ] }, { "cell_type": "markdown", "metadata": { "id": "NZzwELPXvLsD" }, "source": [ "Troubleshooting deep neural networks for speed is challenging.\n", "\n", "There are at least three different common approaches,\n", "each with an increasing level of skill required:\n", "\n", "1. Follow best practices advice from others\n", "([this @karpathy tweet](https://t.co/7CIDWfrI0J), summarizing\n", "[this NVIDIA talk](https://www.youtube.com/watch?v=9mS1fIYj1So&ab_channel=ArunMallya), is a popular place to start) and use existing implementations.\n", "2. Take code that runs slowly and use empirical observations to iteratively improve it.\n", "3. Truly understand distributed, accelerated tensor computations so you can write code correctly from scratch the first time.\n", "\n", "For the full stack deep learning engineer,\n", "the final level is typically out of reach,\n", "unless you're specializing in the model performance\n", "part of the stack in particular.\n", "\n", "So we recommend reaching the middle level,\n", "and this segment of the lab walks through the\n", "tools that make this easier." 
] }, { "cell_type": "markdown", "metadata": { "id": "3_yp87UrFZ8M" }, "source": [ "Because neural network training involves GPU acceleration,\n", "generic Python profiling tools like\n", "[`py-spy`](https://github.com/benfred/py-spy)\n", "won't work, and\n", "we'll need tools specialized for tracing and profiling DNN training." ] }, { "cell_type": "markdown", "metadata": { "id": "yspsYVFGEyZm" }, "source": [ "In general, these tools are for observing what happens while your code is executing:\n", "_tracing_ which operations were happening when and summarizing that into a _profile_ of the code.\n", "\n", "Because they help us observe the execution in detail,\n", "they will also help us understand just what is going on during\n", "a PyTorch training step in greater detail." ] }, { "cell_type": "markdown", "metadata": { "id": "YqXq2hKuvLsE" }, "source": [ "To support profiling and tracing,\n", "we've added a new argument to `training/run_experiment.py`, `--profile`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "z_GMMViWvLsE" }, "outputs": [], "source": [ "!python training/run_experiment.py --help | grep -A 1 -e \"^\\s*--profile\\s\"" ] }, { "cell_type": "markdown", "metadata": { "id": "ZldoksHPvLsE" }, "source": [ "As with experiment management, this relies mostly on features of PyTorch Lightning,\n", "which themselves wrap core utilities from libraries like PyTorch and TensorBoard,\n", "and we just add a few lines of customization:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F2iJ0_A6vLsE" }, "outputs": [], "source": [ "!cat training/run_experiment.py | grep args.profile -A 5" ] }, { "cell_type": "markdown", "metadata": { "id": "Aw3ppgndvLsE" }, "source": [ "For more on profiling with Lightning, see the\n", "[Lightning tutorial](https://pytorch-lightning.readthedocs.io/en/1.6.1/advanced/profiler.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "uCAmNW3QEtcD" }, "source": [ "The cell below runs an epoch of training with tracing and profiling turned on\n", "and then saves the results locally and to W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t4o3ylDgr46F", "scrolled": false }, "outputs": [], "source": [ "import glob\n", "\n", "import torch\n", "import wandb\n", "\n", "from text_recognizer.data.base_data_module import DEFAULT_NUM_WORKERS\n", "\n", "\n", "# make it easier to separate these from training runs\n", "%env WANDB_JOB_TYPE=profile\n", "\n", "batch_size = 16\n", "num_workers = DEFAULT_NUM_WORKERS # change this number later and see how the results change\n", "gpus = 1 # must be run with accelerator\n", "\n", "%run training/run_experiment.py --wandb --profile \\\n", " --max_epochs=1 \\\n", " --num_sanity_val_steps=0 --limit_val_batches=0 --limit_test_batches=0 \\\n", " --model_class=ResnetTransformer --data_class=IAMParagraphs --loss=transformer \\\n", " --batch_size={batch_size} --num_workers={num_workers} --precision=16 --gpus=1\n", "\n", "latest_expt = wandb.run\n", "\n", "try: # add execution trace to logged and versioned binaries\n", " folder = wandb.run.dir\n", " trace_matcher = wandb.run.dir + \"/*.pt.trace.json\"\n", " trace_file = glob.glob(trace_matcher)[0]\n", " trace_at = wandb.Artifact(name=f\"trace-{wandb.run.id}\", type=\"trace\")\n", " trace_at.add_file(trace_file, name=\"training_step.pt.trace.json\")\n", " wandb.log_artifact(trace_at)\n", "except IndexError:\n", " print(\"trace not found\")\n", "\n", "wandb.finish()" ] }, { "cell_type": "markdown", "metadata": { "id": "ePTkS3EqO5tN" }, "source": [ "We get out a table of statistics in the terminal,\n", "courtesy of Lightning.\n", "\n", "Each row lists an operation\n", "and provides information,\n", "described in the column headers,\n", "about the time spent on that operation\n", "across all the training steps we profiled.\n", "\n", "With 
practice, some useful information can be read out from this table,\n", "but it's better to start from both a less detailed view,\n", "in the TensorBoard dashboard,\n", "and a more detailed view,\n", "using the Chrome Trace viewer." ] }, { "cell_type": "markdown", "metadata": { "id": "TzV62f3c7-Bi" }, "source": [ "## High-level statistics from the PyTorch Profiler in TensorBoard" ] }, { "cell_type": "markdown", "metadata": { "id": "mNPKXkYw8NWd" }, "source": [ "Let's look at the profiling info in a high-level TensorBoard dashboard, conveniently hosted for us on W&B." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CbItwuT88eAV" }, "outputs": [], "source": [ "your_tensorboard_url = latest_expt.url + \"/tensorboard\"\n", "\n", "print(your_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "jE_LooMYHFpF" }, "source": [ "If at any point you run into issues,\n", "like the description not matching what you observe,\n", "check out one of our example runs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "za2zybSwIo5C" }, "outputs": [], "source": [ "example_tensorboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59\"\n", "print(example_tensorboard_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "xlrhl1n4HYU6" }, "source": [ "Once the TensorBoard session has loaded up,\n", "we are dropped into the Overview\n", "(see [this screenshot](https://pytorch.org/tutorials/_static/img/profiler_overview1.png)\n", "for an example).\n", "\n", "In the top center, we see the **GPU Summary** for our system.\n", "\n", "In addition to the name of our GPU,\n", "there are a few configuration details and top-level statistics.\n", "They are (tersely) documented\n", "[here](https://github.com/pytorch/kineto/blob/main/tb_plugin/docs/gpu_utilization.md)." 
] }, { "cell_type": "markdown", "metadata": { "id": "MmBhUDgDLhd1" }, "source": [ "- **[Compute Capability](https://developer.nvidia.com/cuda-gpus)**:\n", "this is effectively a coarse \"version number\" for your GPU hardware.\n", "It indexes which features are available,\n", "with more advanced features being available only at higher compute capabilities.\n", "It does not directly index the speed or memory of the GPU." ] }, { "cell_type": "markdown", "metadata": { "id": "voUgT6zuLyi0" }, "source": [ "- **GPU Utilization**: This metric represents the fraction of time an operation (a CUDA kernel) is running on the GPU. This is also reported by the `!nvidia-smi` command or in the system metrics tab in W&B. This metric will be our first target to increase." ] }, { "cell_type": "markdown", "metadata": { "id": "Yl-IndtXE4b4" }, "source": [ "- **[Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)**:\n", "for devices with compute capability of at least 7, you'll see information about how much your execution used DNN-specialized\n", "Tensor Cores.\n", "If you're running on an older GPU without Tensor Cores,\n", "you should consider upgrading.\n", "If you're running a more recent GPU but not seeing Tensor Core usage,\n", "you should switch to half precision (16-bit) or mixed precision floating point numbers,\n", "which Tensor Cores are specialized for." ] }, { "cell_type": "markdown", "metadata": { "id": "XxcUf0bBNXy_" }, "source": [ "- **Est. SM Efficiency** and **Est. Occupancy** are high-level summaries of the utilization of GPU hardware\n", "at a lower level than just whether something is running at all,\n", "as in utilization.\n", "Unlike utilization, reaching 100% is not generally feasible\n", "and sometimes not desirable.\n", "Increasing these numbers requires expertise in\n", "CUDA programming, so we'll target utilization instead." 
] }, { "cell_type": "markdown", "metadata": { "id": "A88pQn4YMMKc" }, "source": [ "- **Execution Summary**: This table and pie chart indicate\n", "how much time within a profiled step\n", "was spent in each category.\n", "The value for \"kernel\" execution here\n", "is equal to the GPU utilization,\n", "and we want that number to be as close to 100%\n", "as possible.\n", "This summary helps us know which\n", "other operations are taking time,\n", "like memory being copied between CPU and GPU (`memcpy`)\n", "or `DataLoader`s executing on the CPU,\n", "so we can decide where the bottleneck is." ] }, { "cell_type": "markdown", "metadata": { "id": "6qjW1RlTQRPv" }, "source": [ "At the very bottom, you'll find a\n", "**Performance Recommendation**\n", "tab that sometimes suggests specific methods for improving performance.\n", "\n", "If this tab makes suggestions, you should certainly take them!" ] }, { "cell_type": "markdown", "metadata": { "id": "pWY5AhrcRQmJ" }, "source": [ "For more on using the profiler in TensorBoard,\n", "including some of the other, more detailed views\n", "available via the \"Views\" dropdown menu, see\n", "[this PyTorch tutorial](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html?highlight=profiler)." 
] }, { "cell_type": "markdown", "metadata": { "id": "mQwrPY_H77H8" }, "source": [ "## Going deeper with the Chrome Trace Viewer" ] }, { "cell_type": "markdown", "metadata": { "id": "yhwo7fslvLsH" }, "source": [ "So far, we've seen summary-level information about our training steps\n", "in the table from Lightning and in the TensorBoard Overview.\n", "These give aggregate statistics about the computations that occurred,\n", "but understanding how to interpret those statistics\n", "and use them to speed up our networks\n", "requires understanding just what is\n", "happening in our training step.\n", "\n", "Fundamentally,\n", "all computations are processes that unfold in time.\n", "\n", "If we want to really understand our training step,\n", "we need to display it that way:\n", "what operations were occurring,\n", "on both the CPU and GPU,\n", "at each moment in time during the training step.\n", "\n", "This information on timing is collected in the trace.\n", "One of the best tools for viewing the trace over time\n", "is the [Chrome Trace Viewer](https://www.chromium.org/developers/how-tos/trace-event-profiling-tool/)." ] }, { "cell_type": "markdown", "metadata": { "id": "wUkZItxYc20A" }, "source": [ "Let's tour the trace we just logged\n", "with an aim to really understanding just\n", "what is happening when we call\n", "`training_step`\n", "and by extension `.forward`, `.backward`, and `optimizer.step`." 
] }, { "cell_type": "markdown", "metadata": { "id": "9w9F2UA7Qctg" }, "source": [ "The Chrome Trace Viewer is built into W&B,\n", "so we can view our traces in their interface.\n", "\n", "The cell below embeds the trace inside the notebook,\n", "but you may wish to open it separately,\n", "with the \"Open page\" button or by navigating to the URL,\n", "so that you can interact with it\n", "as you read the description below.\n", "Display directly on W&B is also a bit less temperamental\n", "than display on W&B inside a notebook.\n", "\n", "Furthermore, note that the Trace Viewer was originally built as part of the Chromium project,\n", "so it works best in browsers in that lineage -- Chrome, Edge, and Opera.\n", "It also can interact poorly with browser extensions (e.g. ad blockers),\n", "so you may need to deactivate them temporarily in order to see it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OMUs4aby6Rfd" }, "outputs": [], "source": [ "trace_files_url = latest_expt.url.split(\"/runs/\")[0] + f\"/artifacts/trace/trace-{latest_expt.id}/latest/files/\"\n", "trace_url = trace_files_url + \"training_step.pt.trace.json\"\n", "\n", "example_trace_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json\"\n", "\n", "print(trace_url)\n", "IFrame(src=trace_url, height=frame_height * 1.5, width=\"100%\")" ] }, { "cell_type": "markdown", "metadata": { "id": "qNVpGeQtQjMG" }, "source": [ "> **Heads up!** We're about to do a tour of the\n", "> precise details of the tracing information logged\n", "> during the execution of the training code.\n", "> The only way to learn how to troubleshoot model performance\n", "> empirically is to look at the details,\n", "> but the details depend on the precise machine being used\n", "> -- GPU and CPU and RAM.\n", "> That means even within Colab,\n", "> these details change from session to session.\n", "> So if you don't 
observe a phenomenon or feature\n", "> described in the tour below, check out\n", "> [the example trace](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-67j1qxws/latest/files/training_step.pt.trace.json)\n", "> on W&B while reading through the next section of the lab,\n", "> and return to your trace once you understand the trace viewer better at the end.\n", "> Also, these are very much bleeding-edge expert developer tools, so the UX and integrations\n", "> can sometimes be a bit janky." ] }, { "cell_type": "markdown", "metadata": { "id": "kXMcBhnCgdN_" }, "source": [ "This trace reveals, in nanosecond-level detail,\n", "what's going on inside of a `training_step`\n", "on both the GPU and the CPU.\n", "\n", "Time is on the horizontal axis.\n", "Colored bars represent method calls,\n", "and the methods called by a method are placed underneath it vertically,\n", "a visualization known as an\n", "[icicle chart](https://www.brendangregg.com/flamegraphs.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "67BsNzDfVIeg" }, "source": [ "Let's orient ourselves with some gross features:\n", "the forwards pass,\n", "GPU kernel execution,\n", "the backwards pass,\n", "and the optimizer step." 
] }, { "cell_type": "markdown", "metadata": { "id": "IBEFgtRCKqrh" }, "source": [ "### The forwards pass" ] }, { "cell_type": "markdown", "metadata": { "id": "5nYhiWesVMjK" }, "source": [ "Type in `resnet` to the search bar in the top-right.\n", "\n", "This will highlight the first part of the forwards passes we traced, the encoding of the images with a ResNet.\n", "\n", "It should be in a vertical block of the trace that says `thread XYZ (python)` next to it.\n", "\n", "You can click the arrows next to that tile to partially collapse these blocks.\n", "\n", "Next, type in `transformerdecoder` to highlight the second part of our forwards pass.\n", "It should be at roughly the same height.\n", "\n", "Clear the search bar so that the trace is in color.\n", "Zoom in on the area of the forwards pass\n", "using the \"zoom\" tool in the floating toolbar,\n", "so you can see more detail.\n", "The zoom tool is indicated by a two-headed arrow\n", "pointing into and out of the screen.\n", "\n", "Switch to the \"drag\" tool,\n", "represented by a four-headed arrow.\n", "Click-and-hold to use this tool to focus\n", "on different parts of the timeline\n", "and click on the individual colored boxes\n", "to see details about a particular method call.\n", "\n", "As we go down in the icicle chart,\n", "we move from a very abstract level in Python (\"`resnet`\", \"`MultiheadAttention`\")\n", "to much more precise `cudnn` and `cuda` operations\n", "(\"`aten::cudnn_convolution`\", \"`aten::native_layer_norm`\").\n", "\n", "`aten` ([no relation to the Pharaoh](https://twitter.com/charles_irl/status/1422232585724432392?s=20&t=Jr4j5ZXhV20xGwUVD1rY0Q))\n", "is the tensor math library in PyTorch\n", "that links to specific backends like `cudnn`." 
] }, { "cell_type": "markdown", "metadata": { "id": "Fq181ybIvLsH" }, "source": [ "### GPU kernel execution" ] }, { "cell_type": "markdown", "metadata": { "id": "IbkWp5aKvLsH" }, "source": [ "Towards the bottom, you should see a section labeled \"GPU\".\n", "The label appears on the far left.\n", "\n", "Within it, you'll see one or more \"`stream`s\".\n", "These are units of work on a GPU,\n", "akin loosely to threads on the CPU.\n", "\n", "When there are colored bars in this area,\n", "the GPU is doing work of some kind.\n", "The fraction of this bar that is filled in with color\n", "is the same as the \"GPU Utilization %\" we've seen previously.\n", "So the first thing to visually assess\n", "in a trace view of PyTorch code\n", "is what fraction of this area is filled with color.\n", "\n", "In CUDA, work is queued up to be\n", "placed into streams and completed, on the GPU,\n", "in a distributed and asynchronous manner.\n", "\n", "The selection of which work to do\n", "is happening on the CPU,\n", "and that's what we were looking at above.\n", "\n", "The CPU and the GPU have to work together to coordinate\n", "this work.\n", "\n", "Type `cuda` into the search bar and you'll see these coordination operations happening:\n", "`cudaLaunchKernel`, for example, is the CPU telling the GPU what to do.\n", "\n", "Running the same PyTorch model\n", "with the same high level operations like `Conv2d` in different versions of PyTorch,\n", "on different GPUs, and even on tensors of different sizes will result\n", "in different choices of concrete kernel operation,\n", "e.g. different matrix multiplication algorithms.\n", "\n", "Type `sync` into the search bar and you'll see places where either work on the GPU\n", "or work on the CPU needs to await synchronization,\n", "e.g. 
copying data from the CPU to the GPU\n", "or the CPU waiting to decide what to do next\n", "on the basis of the contents of a tensor.\n", "\n", "If you see a \"sync\" block above an area\n", "where the stream on the GPU is empty,\n", "you've got a performance bottleneck due to synchronization\n", "between the CPU and GPU.\n", "\n", "To resolve the bottleneck,\n", "head up the icicle chart until you reach the recognizable\n", "PyTorch modules and operations.\n", "Find where they are called in your PyTorch module.\n", "That's a good place to review your code to understand why the synchronization is happening\n", "and remove it if it's not necessary." ] }, { "cell_type": "markdown", "metadata": { "id": "XeMPbu_jvLsI" }, "source": [ "### The backwards pass\n", "\n", "Type in `backward` into the search bar.\n", "\n", "This will highlight components of our backwards pass.\n", "\n", "If you read it from left to right,\n", "you'll see that it begins by calculating the loss\n", "(type `NllLoss2DBackward` into the search bar if you can't find it)\n", "and ends by doing a `ConvolutionBackward`,\n", "the first layer of the ResNet.\n", "It is, indeed, backwards.\n", "\n", "Like the forwards pass,\n", "the backwards pass also involves the CPU\n", "telling the GPU which kernels to run.\n", "It's typically run in a separate\n", "thread from the forwards pass,\n", "so you'll see it separated out from the forwards pass\n", "in the trace viewer.\n", "\n", "Generally, there's no need to specifically optimize the backwards pass --\n", "removing bottlenecks in the forwards pass results in a fast backwards pass.\n", "\n", "One reason why is that these two passes are just\n", "\"transposes\" of one another,\n", "so they share a lot of properties,\n", "and bottlenecks in one become bottlenecks in the other.\n", "We can choose to optimize either one of the two.\n", "But the forwards pass is under our direct control,\n", "so it's easier for us to reason about.\n", "\n", "Another reason is that 
the forwards pass is more likely to have bottlenecks.\n", "The forwards pass is a dynamic process,\n", "with each line of Python adding more to the compute graph.\n", "Backwards passes, on the other hand, use a static compute graph,\n", "the one just defined by the forwards pass,\n", "so more optimizations are possible." ] }, { "cell_type": "markdown", "metadata": { "id": "gWiDw0vCvLsI" }, "source": [ "### The optimizer step" ] }, { "cell_type": "markdown", "metadata": { "id": "ndfkzEdnvLsI" }, "source": [ "Type in `Adam.step` to the search bar to highlight the computations of the optimizer.\n", "\n", "As with the two passes,\n", "we are still using the CPU\n", "to launch kernels on the GPU.\n", "But now the CPU is looping,\n", "in Python, over the parameters\n", "and applying the Adam update rules to each.\n", "\n", "We now know enough to see that\n", "this is not great for our GPU utilization:\n", "there are many areas of gray\n", "in between the colored bars\n", "in the GPU stream in this area.\n", "\n", "In the time it takes CUDA to multiply\n", "thousands of numbers,\n", "Python has not yet finished cleaning up\n", "after its request for that multiplication.\n", "\n", "As of writing in August 2022,\n", "more efficient optimizers are not a stable part of PyTorch (v1.12), but\n", "[there is an unstable API](https://github.com/pytorch/pytorch/issues/68041)\n", "and stable implementations outside of PyTorch.\n", "The standard implementations are\n", "[in NVIDIA's `apex.optimizers` library](https://nvidia.github.io/apex/optimizers.html),\n", "not to be confused with the\n", "[Apex Optimizers Project](https://www.apexoptimizers.com/),\n", "which is a collection of fitness-themed cheetah NFTs." 
] }, { "cell_type": "markdown", "metadata": { "id": "WX0jxeafvLsI" }, "source": [ "## Take-aways for PyTorch performance bottleneck troubleshooting" ] }, { "cell_type": "markdown", "metadata": { "id": "CugD-bK2vLsI" }, "source": [ "Our goal here was to learn some basic principles and tools for identifying\n", "the most common issues and the lowest-hanging fruit in PyTorch code." ] }, { "cell_type": "markdown", "metadata": { "id": "SwHwJkVMHYGA" }, "source": [ "\n", "Here's an overview in terms of a \"host\",\n", "generally the CPU,\n", "and a \"device\", here the GPU.\n", "\n", "- The slow-moving host operates at the level of an abstract compute graph (\"convolve these weights with this input\"), not actual numerical computations.\n", "- During execution, the host's memory stores only metadata about tensors, like their types and shapes. This metadata is needed to select the concrete operations, or CUDA kernels, for the device to run.\n", " - Convolutions with very large filter sizes, for example, might use fast Fourier transform-based convolution algorithms, while the smaller filter sizes typical of contemporary CNNs are generally faster with Winograd-style convolution algorithms.\n", "- The much beefier device executes actual operations, but has no control over which operations are executed. Its memory\n", "stores information about the contents of tensors,\n", "not just their metadata." ] }, { "cell_type": "markdown", "metadata": { "id": "Gntx28p9cBP5" }, "source": [ "Towards that goal, we viewed the trace to get an understanding of\n", "what's going on inside a PyTorch training step." 
] }, { "cell_type": "markdown", "metadata": { "id": "AKvZGPnkeXvq" }, "source": [ "Here's what that means in terms of troubleshooting bottlenecks.\n", "\n", "We want Python to chew its way through looking up the right CUDA kernel and telling the GPU that's what it needs next\n", "before the previous kernel finishes.\n", "\n", "Ideally, the CPU is actually getting far _ahead_ of execution\n", "on the GPU.\n", "If the CPU makes it all the way through the backwards pass before the GPU is done,\n", "that's great!\n", "The GPU(s) are the expensive part,\n", "and it's easy to use multiprocessing so that\n", "the CPU has other things to do.\n", "\n", "This helps explain at least one common piece of advice:\n", "the larger our batches are,\n", "the more work the GPU has to do for the same work done by the CPU,\n", "and so the better our utilization will be." ] }, { "cell_type": "markdown", "metadata": { "id": "XMztpa-TccH4" }, "source": [ "We operationalize our desire to never be waiting on the CPU with a simple metric:\n", "**100% GPU utilization**, meaning a kernel is running at all times.\n", "\n", "This is the aggregate metric reported in the systems tab on W&B or in the output of `!nvidia-smi`.\n", "\n", "You should not buy faster GPUs until you have maxed this out! If you have 50% utilization, the fastest GPU in the world can't give you more than a 2x speedup, and it will cost more than 2x as much." ] }, { "cell_type": "markdown", "metadata": { "id": "7kYBygfScR6z" }, "source": [ "Here are some of the most common issues that lead to low GPU utilization, and how to resolve them:\n", "1. **The CPU is too weak**.\n", "Because so much of the discussion around DNN performance is about GPUs,\n", "it's easy when speccing out a machine to skimp on the CPUs, even though training can bottleneck on CPU operations.\n", "_Resolution_:\n", "Use nice CPUs, like\n", "[threadrippers](https://www.amd.com/en/products/ryzen-threadripper).\n", "2. 
**Too much Python during the `training_step`**.\n", "Python is very slow, so if you throw in a really slow Python operation, like dynamically creating classes or iterating over a bunch of bytes, especially from disk, during the training step, you can end up waiting on a `__init__`\n", "that takes longer than running an entire layer.\n", "_Resolution_:\n", "Look for low utilization areas of the trace\n", "and check what's happening on the CPU at that time\n", "and carefully review the Python code being executed.\n", "3. **Unnecessary Host/Device synchronization**.\n", "If one of your operations depends on the values in a tensor,\n", "like `if xs.mean() >= 0`,\n", "you'll induce a synchronization between\n", "the host and the device and possibly lead\n", "to an expensive and slow copy of data.\n", "_Resolution_:\n", "Replace these operations as much as possible\n", "with purely array-based calculations.\n", "4. **Bottlenecking on the DataLoader**.\n", "In addition to coordinating the work on the GPU,\n", "CPUs often perform heavy data operations,\n", "including communication over the network\n", "and writing to/reading from disk.\n", "These are generally done in parallel to the forwards\n", "and backwards passes,\n", "but if they don't finish before that happens,\n", "they will become the bottleneck.\n", "_Resolution_:\n", "Get better hardware for compute,\n", "memory, and network.\n", "For software solutions, the answer \n", "is a bit more complex and application-dependent.\n", "For generic tips, see\n", "[this classic post by Ross Wightman](https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548/19)\n", "in the PyTorch forums.\n", "For techniques in computer vision, see\n", "[the FFCV library](https://github.com/libffcv/ffcv)\n", "and for techniques in NLP, see e.g.\n", "[Hugging Face datasets with Arrow](https://huggingface.co/docs/datasets/about_arrow)\n", "and [Hugging Face FastTokenizers](https://huggingface.co/course/chapter6/3)." 
] }, { "cell_type": "markdown", "metadata": { "id": "i2WYS8bQvLsJ" }, "source": [ "### Further steps in making DNNs go brrrrrr" ] }, { "cell_type": "markdown", "metadata": { "id": "T0wW2_lRKfY1" }, "source": [ "It's important to note that utilization\n", "is just an easily measured metric\n", "that can reveal common bottlenecks.\n", "Having high utilization does not automatically mean\n", "that your performance is fully optimized.\n", "\n", "For example,\n", "synchronization events between GPUs\n", "are counted as kernels,\n", "so a deadlock during distributed training\n", "can show up as 100% utilization,\n", "despite literally no useful work occurring.\n", "\n", "Just switching to \n", "double precision floats, `--precision=64`,\n", "will generally lead to much higher utilization.\n", "The GPU operations take longer\n", "for roughly the same amount of CPU effort,\n", "but the added precision brings no benefit.\n", "\n", "In particular, it doesn't make for models\n", "that perform better on our correctness metrics,\n", "like loss and accuracy.\n", "\n", "Another useful yardstick to add\n", "to utilization is examples per second,\n", "which incorporates how quickly the model is processing data examples\n", "and calculating gradients.\n", "\n", "But really,\n", "the gold star is _decrease in loss per second_.\n", "This metric connects model design choices\n", "and hyperparameters with purely engineering concerns,\n", "so it disrespects abstraction barriers\n", "and doesn't generally lead to actionable recommendations,\n", "but it is, in the end, the real goal:\n", "make the loss go down faster so we get better models sooner." ] }, { "cell_type": "markdown", "metadata": { "id": "EFzPsplfdo_o" }, "source": [ "For PyTorch internals abstractly,\n", "see [Ed Yang's blog post](http://blog.ezyang.com/2019/05/pytorch-internals/).\n", "\n", "For more on performance considerations in PyTorch,\n", "see [Horace He's blog post](https://horace.io/brrr_intro.html)." 
] }, { "cell_type": "markdown", "metadata": { "id": "RFx-OhF837Bp" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "yq6-S6TC38AY" }, "source": [ "### 🌟 Compare `num_workers=0` with `DEFAULT_NUM_WORKERS`.\n", "\n", "One of the most important features for making\n", "PyTorch run quickly is the\n", "`MultiprocessingDataLoader`,\n", "which executes batching of data in a separate process\n", "from the forwards and backwards passes.\n", "\n", "By default in PyTorch,\n", "this feature is actually turned off,\n", "via the `DataLoader` argument `num_workers`\n", "having a default value of `0`,\n", "but we set the `DEFAULT_NUM_WORKERS`\n", "to a value based on the number of CPUs\n", "available on the system running the code.\n", "\n", "Re-run the profiling cell,\n", "but set `num_workers` to `0`\n", "to turn off multiprocessing.\n", "\n", "Compare and contrast the two traces,\n", "both for total runtime\n", "(see the time axis at the top of the trace)\n", "and for utilization.\n", "\n", "If you're unable to run the profiles,\n", "see the results\n", "[here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/artifacts/trace/trace-2eddoiz7/v0/files/training_step.pt.trace.json#f388e363f107e21852d5$trace-67j1qxws),\n", "which juxtaposes two traces,\n", "with in-process dataloading on the left and\n", "multiprocessing dataloading on the right." ] }, { "cell_type": "markdown", "metadata": { "id": "5D39w0gXAiha" }, "source": [ "### 🌟🌟 Resolve issues with a file by fixing flake8 lints, then write a test." 
] }, { "cell_type": "markdown", "metadata": { "id": "T2i_a5eVeIoA" }, "source": [ "The file below incorrectly implements and then incorrectly tests\n", "a simple PyTorch utility for adding five to every entry of a tensor\n", "and then calculating the sum.\n", "\n", "Even worse, it does it with horrible style!\n", "\n", "The cells below apply our linting checks\n", "(after automatically fixing the formatting)\n", "and run the test.\n", "\n", "Fix all of the lints,\n", "implement the function correctly,\n", "and then implement some basic tests." ] }, { "cell_type": "markdown", "metadata": { "id": "wSon2fB5VVM_" }, "source": [ "- [`flake8`](https://flake8.pycqa.org/en/latest/user/error-codes.html) for core style\n", "- [`flake8-import-order`](https://github.com/PyCQA/flake8-import-order) for checking imports\n", "- [`flake8-docstrings`](https://github.com/pycqa/flake8-docstrings) for docstring style\n", "- [`darglint`](https://github.com/terrencepreilly/darglint) for docstring completeness\n", "- [`flake8-annotations`](https://github.com/sco1/flake8-annotations) for type annotations" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aYiRvU4HA84t" }, "outputs": [], "source": [ "%%writefile training/fixme.py\n", "import torch\n", "from training import run_experiment\n", "from numpy import *\n", "import random\n", "from pathlib import Path\n", "\n", "\n", "\n", "\n", "def add_five_and_sum(tensor):\n", " # this function is not implemented right,\n", " # but it's supposed to add five to all tensor entries and sum them up\n", " return 1\n", "\n", "def test_add_five_and_sum():\n", " # and this test isn't right either! 
plus this isn't exactly a docstring\n", " all_zeros, all_ones = torch.zeros((2, 3)), torch.ones((1, 4, 72))\n", " all_fives = 5 * all_ones\n", " assert False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EXJpmvuzT1w0" }, "outputs": [], "source": [ "!pre-commit run black --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SRO-oJfdUrcQ" }, "outputs": [], "source": [ "!cat training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jM8NHxVbSEQD" }, "outputs": [], "source": [ "!pre-commit run --files training/fixme.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kj0VMBSndtkc" }, "outputs": [], "source": [ "!pytest training/fixme.py" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "lab05_troubleshooting.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab08/notebooks/lab06_data.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "FlH0lCOttCs5" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "ZUPRHaeetRnT" }, "source": [ "# Lab 06: Data Annotation" ] }, { "cell_type": "markdown", "metadata": { "id": "bry3Hr-PcgDs" }, "source": [ "### What You Will Learn\n", "\n", "- How the `IAM` handwriting dataset is structured on disk and how it is processed into an ML-friendly format\n", "- How to setup a [Label Studio](https://labelstud.io/) data annotation server\n", 
"- Just how messy data really is" ] }, { "cell_type": "markdown", "metadata": { "id": "vs0LXXlCU6Ix" }, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "ZkQiK7lkgeXm" }, "source": [ "If you're running this notebook on Google Colab,\n", "the cell below will run full environment setup.\n", "\n", "It should take about three minutes to run." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVx7C7H0PIZC" }, "outputs": [], "source": [ "lab_idx = 6\n", "\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " # needed for inline plots in some contexts\n", " %matplotlib inline\n", "\n", " bootstrap.run = False # change to True re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "DpvaHz9TEGwV" }, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gsXpeXi2EGwV" }, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "IFrame(src=\"https://fsdl.me/2022-lab-06-video-embed\", width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "XTkKzEMNR8XZ" }, "source": [ "# `IAMParagraphs`: From annotated data to a PyTorch `Dataset`" ] }, { "cell_type": "markdown", "metadata": { "id": "3mQLbjuiwZuj" }, "source": [ "We've used the `text_recognizer.data` submodule\n", "and its `LightningDataModule`s -- `IAMLines` and `IAMParagraphs`\n", "for lines and paragraphs of handwritten 
text\n", "from the\n", "[IAM Handwriting Database](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database).\n", "\n", "These classes convert data from a database-friendly format\n", "designed for storage and transfer into the\n", "format our DNNs expect:\n", "PyTorch `Tensor`s.\n", "\n", "In this section,\n", "we'll walk through that process in detail.\n", "\n", "In the following section,\n", "we'll see how data\n", "goes from signals measured in the world\n", "to the format we consume here." ] }, { "cell_type": "markdown", "metadata": { "id": "499c23a6" }, "source": [ "## Dataset structure on disk" ] }, { "cell_type": "markdown", "metadata": { "id": "a3438d2e" }, "source": [ "We begin by downloading the raw data to disk." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "18900eec" }, "outputs": [], "source": [ "from text_recognizer.data.iam import IAM\n", "\n", "iam = IAM()\n", "iam.prepare_data()" ] }, { "cell_type": "markdown", "metadata": { "id": "a332f359" }, "source": [ "The `IAM` dataset is downloaded as a zip file\n", "and then unzipped:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d6c44266" }, "outputs": [], "source": [ "from text_recognizer.metadata.iam import DL_DATA_DIRNAME\n", "\n", "\n", "iam_dir = DL_DATA_DIRNAME\n", "!ls {iam_dir}" ] }, { "cell_type": "markdown", "metadata": { "id": "8463c2d1" }, "source": [ "The unzipped dataset is not simply a flat directory of files.\n", "\n", "Instead, there are a number of subfolders,\n", "each of which contains a particular type of data or metadata." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "536924f7" }, "outputs": [], "source": [ "iamdb = iam_dir / \"iamdb\"\n", "\n", "!du -h {iamdb}" ] }, { "cell_type": "markdown", "metadata": { "id": "b745a594" }, "source": [ "For example, the `task` folder contains metadata about canonical dataset splits:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "84c21f75" }, "outputs": [], "source": [ "!find {iamdb / \"task\"} | grep \"\\\\.txt$\"" ] }, { "cell_type": "markdown", "metadata": { "id": "mEb0Pdm4vIHe" }, "source": [ "We find the images of handwritten text in the `forms` folder.\n", "\n", "An individual \"datapoint\" in `IAM` is a \"form\",\n", "because the humans whose hands wrote the text were prompted to write on \"forms\",\n", "as below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "945d5e3a" }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "\n", "form_fn, = !find {iamdb}/forms | grep \".jpg$\" | sort | head -n 1\n", "\n", "print(form_fn)\n", "Image(filename=form_fn, width=\"360\")" ] }, { "cell_type": "markdown", "metadata": { "id": "b9e9e384" }, "source": [ "Meanwhile, the `xml` files contain the data annotations,\n", "written out as structured text:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6add5c5a" }, "outputs": [], "source": [ "xml_fn, = !find {iamdb}/xml | grep \"\\.xml$\" | sort | head -n 1\n", "\n", "!cat {xml_fn} | grep -A 100 \"handwritten-part\" | grep \"" ] }, { "cell_type": "markdown", "metadata": { "id": "MX9n-Zed8G_T" }, "source": [ "# Lab 07: Deployment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What You Will Learn\n", "\n", "- How to convert PyTorch models into portable TorchScript binaries\n", "- How to use `gradio` to make a simple demo UI for your ML-powered applications\n", "- How to split out a model service from the frontend and spin up a publicly accessible application" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "45D6GuSwvT7d" }, "outputs": [], "source": [ "lab_idx = 7\n", "\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " # change into the lab directory\n", " bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " bootstrap.run = False # change to True re-run setup\n", " \n", "!pwd\n", "%ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pzi8qYKI-njP" }, "outputs": [], "source": [ "from IPython.display import display, HTML, IFrame\n", "\n", "full_width = True\n", "frame_height = 720 # adjust for your screen\n", "\n", "if full_width: # if we want the notebook to take up the whole width\n", " # add styling to the notebook's HTML directly\n", " display(HTML(\"\"))\n", " display(HTML(\"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "IFrame(src=\"https://fsdl.me/2022-lab-07-video-embed\", width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "SAw7BEI_sCZZ" }, "source": [ "# Making the model portable" ] }, { "cell_type": "markdown", "metadata": { "id": "8zL0K2Xe-MWJ" }, "source": [ "While training the model,\n", "we've saved checkpoints and stored them locally\n", "and on W&B.\n", "\n", "From these checkpoints, we can reload model weights\n", "and even restart training if we 
are in or can recreate\n", "the model development environment.\n", "\n", "We could directly deploy these checkpoints into production,\n", "but they're suboptimal for two reasons.\n", "\n", "First, as the name suggests,\n", "these \"checkpoints\" are designed for serializing\n", "state at a point of time in training.\n", "\n", "That means they can include lots of information\n", "not relevant during inference,\n", "e.g. optimizer states like running average gradients.\n", "\n", "Additionally, the model development environment\n", "is much more heavyweight than what we need during inference.\n", "\n", "For example, we've got Lightning for training models\n", "and W&B for tracking training runs.\n", "\n", "These in turn incur dependencies on lots of heavy data science libraries.\n", "\n", "We don't need this anymore -- we just want to run the model.\n", "\n", "These are effectively \"compiler tools\", which our runtime model doesn't need.\n", "\n", "So we need a new model binary artifact for runtime\n", "that's leaner and more independent.\n", "\n", "For this purpose, we use TorchScript." ] }, { "cell_type": "markdown", "metadata": { "id": "0bMPqKDjs623" }, "source": [ "## Compiling models to TorchScript" ] }, { "cell_type": "markdown", "metadata": { "id": "7d9EmZ0j_AQF" }, "source": [ "Torch has two main facilities for creating\n", "more portable model binaries:\n", "_scripting_ and _tracing_." 
] }, { "cell_type": "markdown", "metadata": { "id": "h9PVzwjQ_YHg" }, "source": [ "Scripting produces a binary that combines\n", "constant `Tensor` values\n", "(like weights and positional embeddings)\n", "with a program that describes how to use them.\n", "\n", "The result is a program that creates a dynamic graph,\n", "as does a normal PyTorch program,\n", "but this program is written in a\n", "sub-dialect of Python called\n", "_TorchScript_.\n", "\n", "The [TorchScript sub-dialect of Python](https://pytorch.org/docs/stable/jit_language_reference.html#language-reference)\n", "is more performant\n", "and can even be run without a Python interpreter.\n", "\n", "For example, TorchScript programs can be executed in pure C++\n", "[using LibTorch](https://pytorch.org/tutorials/advanced/cpp_export.html).\n", "\n", "You can read more in the documentation for the primary method\n", "for scripting models, `torch.jit.script`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "h1VtGt_Xj_H7" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "torch.jit.script??" ] }, { "cell_type": "markdown", "metadata": { "id": "tUOm7G9ESi4s" }, "source": [ "The primary alternative to scripting is _tracing_,\n", "which runs the PyTorch module on a specific\n", "set of inputs and records, or \"traces\",\n", "the compute graph.\n", "\n", "You can read more about it in the documentation for the primary method\n", "for tracing models, `torch.jit.trace`,\n", "or just read the quick summary and comparison below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pn3QLOFNjuOa" }, "outputs": [], "source": [ "torch.jit.trace??" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tracing versus Scripting for TorchScript" ] }, { "cell_type": "markdown", "metadata": { "id": "uP4TfihfBw9z" }, "source": [ "The traced program is generally faster than the scripted version,\n", "for models that are compatible with both tracing and scripting.\n", "\n", "Tracing produces a static compute graph,\n", "which means all control flow\n", "(`if`s or `for` loops)\n", "is effectively inlined.\n", "\n", "As written, our text recognizer has a loop with conditional breaking -- fairly typical for Transformers in autoregressive mode --\n", "so it isn't compatible with tracing.\n", "\n", "Furthermore, the static compute graph includes concrete choices of operations,\n", "e.g. specific CUDA kernels if tracing is run on the GPU.\n", "\n", "If you try to run the traced model on a system that doesn't support those kernels,\n", "it will crash.\n", "That means tracing must occur in the target deployment environment.\n", "\n", "Scripted models are much more portable, at the cost of both slower runtimes\n", "for a fixed hardware target and of some restrictions on how dynamic the Python code can be.\n", "\n", "We don't find the restrictions scripting places on Python code to be too onerous\n", "and in our experience, the performance gains are not worth the extra effort\n", "until the team size is larger,\n", "model serving hardware and strategy are more mature,\n", "and model release cycles are slower.\n", "\n", "For an alternative perspective that's more in favor of tracing\n", "and walks through how to mix-and-match scripting\n", "and tracing for maximum flexibility and performance, see\n", "[this blogpost](https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/)\n", "from\n", "[Detectron2](https://ai.facebook.com/blog/-detectron2-a-pytorch-based-modular-object-detection-library-/)\n", "dev Yuxin Wu." 
] }, { "cell_type": "markdown", "metadata": { "id": "cDARv-GdqtET" }, "source": [ "Choosing just one of scripting or tracing\n", "means we can use a high-level method\n", "from PyTorch Lightning,\n", "`to_torchscript`,\n", "to produce our scripted model binary\n", "and we don't need to touch our model code." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "udvnx7sBBklY" }, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "\n", "\n", "pl.LightningModule.to_torchscript??" ] }, { "cell_type": "markdown", "metadata": { "id": "iXftpJBizrM6" }, "source": [ "## Alternatives to TorchScript" ] }, { "cell_type": "markdown", "metadata": { "id": "QvFh_SW8v4p6" }, "source": [ "Though it has some sharp edges,\n", "TorchScript is a relatively easy to use tool\n", "for compiling neural networks written in PyTorch.\n", "\n", "If you're willing to tolerate more sharp edges,\n", "e.g. limited support for certain ops\n", "and a higher risk of subtle differences in behavior, the\n", "[Open Neural Network eXchange](https://onnx.ai/)\n", "format, ONNX, is a compilation target for\n", "[a wide variety of DNN libraries](https://onnx.ai/supported-tools.html),\n", "from `sklearn` and MATLAB\n", "to PyTorch and Hugging Face.\n", "\n", "A high-level utility for conversion to ONNX is also included\n", "in PyTorch Lightning, `pl.LightningModule.to_onnx`.\n", "\n", "Because it is framework agnostic,\n", "there's more and more varied tooling around ONNX,\n", "and it has smoother paths to\n", "compilation targets that can run DNNs\n", "at the highest possible speeds,\n", "like\n", "[NVIDIA's TensorRT](https://developer.nvidia.com/tensorrt)\n", "or\n", "[Apache TVM](https://tvm.apache.org/2017/08/17/tvm-release-announcement).\n", "\n", "TensorRT is the model format used in the\n", "[Triton Inference Server](https://github.com/triton-inference-server/server),\n", "a sort of \"kubernetes for GPU-accelerated DNNs\"\n", "that is, as of 2022,\n", "the state of the 
art in running deep networks\n", "at maximum throughput on server-grade GPUs.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "36dKPerevkhZ" }, "source": [ "## A simple script for compiling and staging models" ] }, { "cell_type": "markdown", "metadata": { "id": "93pc-NLrBR1A" }, "source": [ "To recap, our model staging workflow,\n", "which does the hand-off between training and production, looks like this:\n", "\n", "1. Get model weights and hyperparameters\n", "from a tracked training run in W&B's cloud storage.\n", "2. Reload the model as a `LightningModule` using those weights and hyperparameters.\n", "3. Call `to_torchscript` on it.\n", "4. Save that result to W&B's cloud storage.\n", "\n", "We provide a simple script to implement this process:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gqgiWO0tFktU" }, "outputs": [], "source": [ "%run training/stage_model.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "i4qEqMRkFsd4" }, "source": [ "Here in this notebook,\n", "rather than training or scripting a model ourselves,\n", "we'll just `--fetch`\n", "an already trained and scripted model binary:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c2wfjLmRDwrH" }, "outputs": [], "source": [ "%run training/stage_model.py --fetch --entity=cfrye59 --from_project=fsdl-text-recognizer-2021-training" ] }, { "cell_type": "markdown", "metadata": { "id": "I0uNnvjkCZzX" }, "source": [ "Note that we can use the metadata of the staged model\n", "to find the training run that generated the model weights.\n", "It requires two graph hops:\n", "find the run that created the staged TorchScript model\n", "then in that run,\n", "find the model checkpoint artifact\n", "and look for the run that created it." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "E9zJg44hCjRv" }, "outputs": [], "source": [ "from IPython import display\n", "\n", "\n", "staged_model_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/artifacts/prod-ready/paragraph-text-recognizer/3e07efa34aec61999c5a/overview\"\n", "\n", "IFrame(staged_model_url, width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we're deploying our first model,\n", "this doesn't feel that important --\n", "it's easy enough to find the training runs\n", "we've executed and connect them to the model in production.\n", "\n", "But as we train and release more models,\n", "this information will become harder to find\n", "and automation and API access will become more important.\n", "\n", "This will be especially true if we adopt more sophisticated rollout strategies,\n", "like A/B testing or canarying,\n", "as the application matures.\n", "\n", "Our system here is not robust enough to be Enterprise Grade™️ --\n", "marking models as \"in production\" is manual\n", "and there are no access control planes built in --\n", "but at least the information is preserved." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running our more portable model via a CLI" ] }, { "cell_type": "markdown", "metadata": { "id": "X7d2WHSCHHHP" }, "source": [ "Now that our TorchScript model binary file is present,\n", "we can spin up our text recognizer\n", "with much less code.\n", "\n", "We just need a compatible version of PyTorch\n", "and methods to convert\n", "our generic data types\n", "(images, strings)\n", "to and from PyTorch `Tensor`s.\n", "\n", "We can put all this together in\n", "a single light-weight object,\n", "the `ParagraphTextRecognizer` class:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZGXZep-nDiDk" }, "outputs": [], "source": [ "from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizer\n", "\n", "\n", "ParagraphTextRecognizer??\n", "\n", "ptr = ParagraphTextRecognizer()" ] }, { "cell_type": "markdown", "metadata": { "id": "uwVo6BoeGmTW" }, "source": [ "And from there,\n", "we can start running on images\n", "and inferring the text that they contain:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CMZlfIoeG3hy" }, "outputs": [], "source": [ "from IPython.display import Image\n", "\n", "example_input = \"text_recognizer/tests/support/paragraphs/a01-077.png\"\n", "\n", "print(ptr.predict(example_input))\n", "Image(example_input)" ] }, { "cell_type": "markdown", "metadata": { "id": "I6AHq1TH44Jq" }, "source": [ "As usual,\n", "we write our Python code\n", "so that it can be imported as a module\n", "and run in a Jupyter notebook,\n", "for documentation and experimentation,\n", "and we make it executable as a script\n", "for easier automation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "igY7sd8eGGI3" }, "outputs": [], "source": [ "%run text_recognizer/paragraph_text_recognizer.py --help\n", "\n", "%run text_recognizer/paragraph_text_recognizer.py {example_input}" ] }, { "cell_type": "markdown", "metadata": { 
"id": "MvYmSN0rE2BP" }, "source": [ "Notice that the `filename` here can be a local file, a URL, or even a cloud storage URI.\n", "\n", "Rather than writing the logic for handling these different cases,\n", "we use the\n", "[`smart_open` library](https://pypi.org/project/smart-open/)." ] }, { "cell_type": "markdown", "metadata": { "id": "3WQ-P16VC94R" }, "source": [ "## Testing our model development pipeline" ] }, { "cell_type": "markdown", "metadata": { "id": "0kVq2iBJDZH5" }, "source": [ "Creating models is _the_ critical function of our code base,\n", "so it's important that we test it,\n", "at the very least with \"smoke tests\" that let us know\n", "if the code is completely broken.\n", "\n", "Right now we have tests for data loading and model training,\n", "but no tests for end-to-end model development,\n", "which combines data loading, model training, and model compilation.\n", "\n", "So we add a simple model development test\n", "that trains a model for a very small number of steps\n", "and then runs our staging script.\n", "\n", "This model development test script returns an error code (`exit 1`) if the process of\n", "building a model fails (`\"$FAILURE\" = true`).\n", "\n", "We use\n", "[the `||` operator](https://www.unix.com/shell-programming-and-scripting/42417-what-does-mean-double-pipe.html)\n", "to set the `FAILURE` variable to `true` if any of the key commands in model development fail." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XPkwFxklDA5V", "scrolled": false }, "outputs": [], "source": [ "!cat training/tests/test_model_development.sh" ] }, { "cell_type": "markdown", "metadata": { "id": "pQ21iRDqFvxj" }, "source": [ "As a next step to improve the coverage of this test,\n", "we might compare the model's outputs\n", "on the same inputs before and after compilation." 
] }, { "cell_type": "markdown", "metadata": { "id": "hyXZhgqEvfe9" }, "source": [ "### Cleaning up artifacts" ] }, { "cell_type": "markdown", "metadata": { "id": "l22DqhC4GIJT" }, "source": [ "The final few lines of the testing script mention\n", "\"`selecting for deletion`\" some artifacts." ] }, { "cell_type": "markdown", "metadata": { "id": "EbIW5okFGQv7" }, "source": [ "As we incorporate more of our code into testing\n", "and develop more models,\n", "the amount of information we are storing on W&B increases.\n", "\n", "We're already uploading model checkpoints, several gigabytes per model training run,\n", "and now we're also looking at uploading several hundred megabytes\n", "of model data per execution of our test." ] }, { "cell_type": "markdown", "metadata": { "id": "T7aBCfpuJVJV" }, "source": [ "Artifact storage is free up to 100GB,\n", "but storing more requires a paid account.\n", "\n", "That means it literally pays to clean up after ourselves.\n", "\n", "We use a very simple script to select certain artifacts for deletion.\n", " \n", "> ⚠️ **Don't use this untested demonstration script in important environments!** ⚠️\n", "We include options for `-v`erbose output and a `--dryrun` mode,\n", "which are both critical for destructive actions that have access\n", "to model weights that might cost $1000s to produce.\n", "\n", "See the `--help` below for more on cleaning up artifacts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8hSqzRplITVB" }, "outputs": [], "source": [ "%run training/cleanup_artifacts.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "BfB38ywTJDMT" }, "source": [ "## Tuning inference performance on CPU and GPU" ] }, { "cell_type": "markdown", "metadata": { "id": "zau0MRr1FPw-" }, "source": [ "Apart from compilation to TorchScript,\n", "the biggest difference for running the model in production\n", "is that now all of our operations occur on the CPU.\n", "\n", "This is a surprising feature of DNN deployment\n", "that's worth thinking about in detail.\n", "\n", "Why isn't it a given that deep network inference\n", "runs on GPUs, when that's so critical for deep network training?\n", "\n", "First,\n", "not many web applications use GPUs,\n", "so there aren't nearly as many good tools and techniques\n", "for deploying GPU-backed services.\n", "\n", "But there's another, deeper reason:\n", "GPUs are not as easy to run efficiently\n", "during inference as they are in training.\n", "\n", "In training,\n", "we use static or synthetic datasets\n", "and our training code is in charge\n", "of the query patterns.\n", "\n", "In particular,\n", "we can request exactly as many inputs\n", "as we want to produce a batch\n", "that makes optimal use\n", "of our expensive GPUs.\n", "\n", "In production, requests arrive independently,\n", "according to the whims of our users.\n", "\n", "This makes batching challenging,\n", "and by far the simplest service architecture\n", "just runs on each request as it arrives.\n", "\n", "But that tanks GPU utilization.\n", "\n", "GPUs are highly parallel computers,\n", "and batch is the easiest dimension to parallelize on --\n", "for example, we load the model weights into memory once,\n", "use them, and then release the memory.\n", "\n", "The cell below\n", "compares two traces\n", "for a GPU-accelerated\n", "Text Recognizer model running\n", "on a single input and on 
a batch.\n", "\n", "For a simple summary,\n", "you can compare the two profiles in TensorBoard\n", "([batch size 1 here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-labs-lab05_training/runs/1vj48h6j/tensorboard?workspace=user-cfrye59),\n", "[batch size 16 here](https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-training/runs/67j1qxws/tensorboard?workspace=user-cfrye59)).\n", "\n", "GPU utilization,\n", "our baseline metric for model performance,\n", "is under 50% with batch size 1,\n", "as compared to >90% with batch size 16,\n", "which fills up GPU RAM.\n", "\n", "You can also look through the traces for more details:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B_NZPLWa-ZVP" }, "outputs": [], "source": [ "trace_comparison_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2022-labs-lab05_training/reports/Trace-Comparison-Batch-Size-16-vs-1--VmlldzoyNTg2MTU4\"\n", "\n", "print(trace_comparison_url)\n", "IFrame(src=trace_comparison_url, width=\"100%\", height=frame_height)" ] }, { "cell_type": "markdown", "metadata": { "id": "_6_U1OyU-Vsi" }, "source": [ "But performance during inference is not as simple \n", "as just \"maximize GPU utilization\".\n", "\n", "In particular, throughput for the GPU with batch size 16\n", "is over 2x better,\n", "one example per 8 ms vs\n", "one example per 40 ms,\n", "but latency is much worse.\n", "\n", "It takes 140ms to complete the batch of size 16.\n", "In the intervening time no examples are completed,\n", "and all 16 users are waiting on a response.\n", "\n", "For comparison,\n", "running one example at a time\n", "would get the first user's result\n", "in just 40 ms,\n", "but the total processing time for all 16 examples would be\n", "640 ms.\n", "\n", "For user experience, latency is critical,\n", "but for making the most efficient use of hardware,\n", "throughput is generally more important.\n", "\n", "During training, we care much less about latency\n", "and much more about computing 
gradients as fast as possible,\n", "so we aim for larger batch sizes.\n", "\n", "Because of the need for efficient use of hardware,\n", "running on single inputs isn't always feasible.\n", "\n", "The usual solution is to run a queue,\n", "which collects up enough requests for a batch.\n", "\n", "One of the easiest ways to do this as of writing in September 2022 is to use\n", "[`cog` by Replicate](https://github.com/replicate/cog),\n", "which both solves difficult issues with containerizing\n", "models with GPU acceleration \n", "and includes, as a beta feature, a built-in Redis queue\n", "for batching requests and responses.\n", "\n", "But note that we can't just run a queue that waits for,\n", "say, 16 user requests\n", "to build up, then runs them all.\n", "If 15 requests come in at once,\n", "but then no requests come for an hour,\n", "all 15 users will be waiting for an hour\n", "for their responses --\n", "much worse than just waiting a few hundred extra milliseconds!\n", "\n", "We need to make sure the queue flushes after a certain amount of time,\n", "regardless of how many requests it has received,\n", "complicating our implementation.\n", "\n", "Running single inputs on GPUs\n", "and running a naive queue\n", "are two different ways it's easy to accidentally tank latency\n", "while pursuing efficiency,\n", "at least for some fraction of cases.\n", "\n", "So we stick with CPU inference." 
] }, { "cell_type": "markdown", "metadata": { "id": "te-CYidTslPo" }, "source": [ "# Building a simple model UI" ] }, { "cell_type": "markdown", "metadata": { "id": "4kGXwQvjJq32" }, "source": [ "With compilation,\n", "we've moved from a model that can only run\n", "in a very special environment\n", "and with lots of support code\n", "into something lightweight\n", "that runs with a simple CLI.\n", "\n", "If we want users to send data to our model\n", "and get useful predictions out,\n", "we need to create a UI.\n", "\n", "But a CLI is not a UI --\n", "it's at best the foundation out of which a UI is built.\n", "\n", "This is not just a concern once the model is finished:\n", "a UI is an incredible tool for model debugging.\n", "\n", "It's hard to overstate the difference between\n", "a static, CLI or code-writing workflow\n", "for sending information to a model\n", "and an interactive interface.\n", "\n", "When your model is easily accessible on a mobile phone,\n", "when you can copy-paste text from elsewhere on your machine or the internet,\n", "or when you can upload arbitrary files,\n", "the whole range of possible inputs becomes clear\n", "in a way that's very hard to replicate with fixed data sets." ] }, { "cell_type": "markdown", "metadata": { "id": "S163btePLB1K" }, "source": [ "Unfortunately, creating a GUI from scratch is not easy,\n", "especially in Python.\n", "\n", "The best tool for GUIs is the browser,\n", "but the lingua franca of the browser\n", "is JavaScript\n", "([for now](https://webassembly.org/)).\n", "\n", "As full stack deep learning engineers,\n", "we're already writing Python with C/C++ acceleration,\n", "we're gluing scripts together with Bash,\n", "and we need to know enough SQL to talk to databases.\n", "\n", "Do we now need to learn front-end web development too?" 
] }, { "cell_type": "markdown", "metadata": { "id": "oSeBo0MzL0H9" }, "source": [ "In the long term, it's a good investment,\n", "and we recommend\n", "[The Odin Project](https://www.theodinproject.com/),\n", "a free online course and community for learning web development.\n", "\n", "Their\n", "[Foundations course](https://www.theodinproject.com/paths/foundations/courses/foundations#html-foundations),\n", "starting from HTML foundations and proceeding\n", "through basic CSS\n", "and JavaScript,\n", "is a great way to dip your toes in\n", "and learn enough about building websites and UIs\n", "in the browser to be dangerous." ] }, { "cell_type": "markdown", "metadata": { "id": "q-7pJcsCL_84" }, "source": [ "In the short term,\n", "we write our frontends in Python libraries\n", "that effectively write the frontend JavaScript/CSS/HTML\n", "for us.\n", "\n", "For the past few years,\n", "[Streamlit](https://streamlit.io/)\n", "has been a popular choice for the busy Python data scientist.\n", "\n", "It remains a solid choice,\n", "and tooling for building complex apps with Streamlit is more mature." 
] }, { "cell_type": "markdown", "metadata": { "id": "xey5gzr5tV51" }, "source": [ "We use the\n", "[`gradio` library](https://gradio.app/),\n", "which includes a simple API for wrapping\n", "a single Python function into a frontend\n", "in addition to a less mature, lower-level API\n", "for building apps more flexibly.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2XvUr7irMHQ6" }, "source": [ "This iteration of the FSDL codebase\n", "includes a new module,\n", "`app_gradio`,\n", "that makes a simple UI for the Text Recognizer\n", "using `gradio`.\n", "\n", "The core component is a script,\n", "`app_gradio/app.py`,\n", "that can be used to spin up our model and UI\n", "from the command line:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w2Ra8ot292XX" }, "outputs": [], "source": [ "%run app_gradio/app.py --help" ] }, { "cell_type": "markdown", "metadata": { "id": "J9bP3zFo9_YY" }, "source": [ "But one very nice feature of `gradio`\n", "is that it is designed to run as easily\n", "from the notebook as from the command line.\n", "\n", "Let's import the contents of `app.py`\n", "and take a look,\n", "then launch our UI." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vryi5r6gDj6D" }, "outputs": [], "source": [ "from app_gradio import app\n", "\n", "\n", "app.make_frontend??\n", "frontend = app.make_frontend(ptr.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use `gradio`'s high-level API, `gr.Interface`,\n", "to build a UI by wrapping our `ptr.predict` function,\n", "defining its inputs\n", "(an `Image`)\n", "and outputs\n", "(a `Textbox`),\n", "and specifying some formatting\n", "and styling choices." 
] }, { "cell_type": "markdown", "metadata": { "id": "m0HxOukBNn13" }, "source": [ "\n", "\n", "We can spin up our UI with the `.launch` method,\n", "and now we can interact\n", "with the model from inside the notebook.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XoVFtGbuDlTL" }, "outputs": [], "source": [ "frontend.launch(share=True, width=\"100%\")" ] }, { "cell_type": "markdown", "metadata": { "id": "okcoAW7sM13h" }, "source": [ "For 72 hours, we can also access the model over the public internet\n", "using a URL provided by `gradio`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x5pEhMECNIT6" }, "outputs": [], "source": [ "print(frontend.share_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "LYfi-lZqNNZd" }, "source": [ "You can point your browser to that URL\n", "to see what the model looks like as a full-fledged web application,\n", "instead of a widget inside the notebook." ] }, { "cell_type": "markdown", "metadata": { "id": "2L5uZCJlOGi4" }, "source": [ "In addition to this UI,\n", "`gradio` also creates a simple REST API,\n", "so we can make requests\n", "from outside the browser,\n", "programmatically,\n", "and get responses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XOngmAWvQnqg" }, "outputs": [], "source": [ "%env API_URL={frontend.share_url + \"/api\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "cj6XSur7Nlzf" }, "source": [ "We can see the details of the API by clicking\n", "\"view api\" at the bottom of the Gradio interface.\n", "\n", "In particular,\n", "we can see that the API expects image data in\n", "[base64 format](https://developer.mozilla.org/en-US/docs/Glossary/Base64),\n", "which encodes binary data as ASCII text\n", "so that it can be sent over interfaces that expect ASCII text." 
] }, { "cell_type": "markdown", "metadata": { "id": "igeFyT84WqqG" }, "source": [ "The line below encodes an image with the `base64` utility,\n", "packages it into the appropriate JSON format\n", "and uses `echo` to pipe it into a `curl` command.\n", "\n", "`curl` can be used to make requests to web services at URLs\n", "-- here `${API_URL}/predict` --\n", "of specific types\n", "-- here `POST` --\n", "that include `-d`ata\n", "and `-H`eaders identifying the format of the data.\n", "\n", "The response is returned as\n", "[string-formatted JSON](https://developer.mozilla.org/en-US/docs/Learn/JavaScript/Objects/JSON)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_nmRbYQCOd3t" }, "outputs": [], "source": [ "response, = ! \\\n", " (echo -n '{ \"data\": [\"data:image/png;base64,'$(base64 -w0 -i text_recognizer/tests/support/paragraphs/a01-077.png)'\"] }') \\\n", " | curl -s -X POST \"${API_URL}/predict\" -H 'Content-Type: application/json' -d @-\n", " \n", "response" ] }, { "cell_type": "markdown", "metadata": { "id": "tLy9z593X4_o" }, "source": [ "JSON, short for \"JavaScript Object Notation\",\n", "is effectively the standard for representing dictionaries\n", "when sharing information between applications\n", "that may be written in different languages.\n", "\n", "With the standard library's `json.loads`,\n", "we can convert the response into a Python dictionary\n", "and then access the response `data` within." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GL4L8o4KRQLx" }, "outputs": [], "source": [ "import json\n", "\n", "\n", "print(json.loads(response)[\"data\"][0])" ] }, { "cell_type": "markdown", "metadata": { "id": "rhOc0fgrRtuO" }, "source": [ "Importantly, the `echo | curl` command\n", "does not need to be run from the same machine that is running the model --\n", "that's another big win for this UI over the CLI script we ran previously.\n", "\n", "Try running the command from your own machine,\n", "if you are running OS X or Linux,\n", "and see if you can get a response.\n", "\n", "Don't forget to define the `API_URL` environment variable on your machine\n", "and download the image file,\n", "`text_recognizer/tests/support/paragraphs/a01-077.png`,\n", "changing the path if needed." ] }, { "cell_type": "markdown", "metadata": { "id": "cd1UZiM3ZVWz" }, "source": [ "Once you're done,\n", "turn off the Gradio interface by running the `.close` method." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mVyv6KjxJhEb" }, "outputs": [], "source": [ "frontend.close()" ] }, { "cell_type": "markdown", "metadata": { "id": "qnJpCdI7SHiX" }, "source": [ "## Testing our UI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've added a lot of new functionality here,\n", "and some of it is critical to our application.\n", "\n", "The surface area is too large and\n", "the components too complex for testing in depth\n", "to be worth the investment --\n", "do we really want to set up a\n", "[headless browser](https://www.browserstack.com/guide/what-is-headless-browser-testing)\n", "or similar mock test to check whether our README is being loaded properly?\n", "\n", "So once again, we pick the minimal test that checks whether\n", "the core functionality is working:\n", "we spin up our frontend and ping the API,\n", "making sure we get back a\n", "[`200 
OK`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/200)\n", "response, indicating that at least the server thinks everything is fine." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!cat app_gradio/tests/test_app.py" ] }, { "cell_type": "markdown", "metadata": { "id": "IwUhy-swZndq" }, "source": [ "## Start here, finish anywhere" ] }, { "cell_type": "markdown", "metadata": { "id": "FTKMGCasMznl" }, "source": [ "You may be concerned:\n", "is `gradio` a children's toy?\n", "am I painting myself into a corner\n", "by using such a high-level framework and doing web development in Python?\n", "shouldn't I be using Ruby On Rails/Angular/React/WhateversNext.js?\n", "\n", "DALL-E Mini, now\n", "[crAIyon](https://www.craiyon.com/),\n", "began its life as\n", "[a Gradio app](https://huggingface.co/spaces/dalle-mini/dalle-mini)\n", "built by FSDL alumnus\n", "[Boris Dayma](https://twitter.com/borisdayma).\n", "\n", "Gradio and similar tools\n", "are critical for quickly getting to an MVP\n", "and getting useful feedback on your model.\n", "\n", "Expend your engineering effort on data and training,\n", "not frontend interface development,\n", "until you're sure you've got something people want to use." 
] }, { "cell_type": "markdown", "metadata": { "id": "8BpPtj6tsP-Y" }, "source": [ "# Wrapping a model into a model service" ] }, { "cell_type": "markdown", "metadata": { "id": "ButF0a6PSbMi" }, "source": [ "We've got an interactive interface for our model\n", "that we can share with friends, colleagues,\n", "potential users, or stakeholders,\n", "which is huge.\n", "\n", "But we have a problem:\n", "our model is running in the same place as our frontend.\n", "\n", "This is simple,\n", "but it ties too many things together.\n", "\n", "First, it ties together execution of the two components.\n", "\n", "If the model has a heart attack due to misformatted inputs\n", "or some mysterious DNN bug,\n", "the server goes down.\n", "The same applies in reverse --\n", "the only API for the model is provided by `gradio`,\n", "so a frontend issue means the model is inaccessible.\n", "\n", "Additionally, it ties together dependencies,\n", "since our server and our model are in the same\n", "environment.\n", "\n", "Lastly, it ties together the hardware used to run our\n", "server and our model.\n", "\n", "That's bad because the server and the model scale differently.\n", "Running the server at scale has different memory and computational requirements\n", "than does running the model at scale." ] }, { "cell_type": "markdown", "metadata": { "id": "HNoMc7fRcETy" }, "source": [ "We could just run another server --\n", "even writing it in Gradio if we wanted! --\n", "for the model.\n", "This is common with GPU inference,\n", "especially when doing queueing, caching,\n", "and other advanced techniques for improving\n", "model efficiency and latency.\n", "\n", "But that's potentially expensive --\n", "we're running two machines,\n", "which costs twice as much.\n", "\n", "Furthermore, this setup is harder to scale \"horizontally\".\n", "\n", "We'll pretty quickly need a solution for auto-scaling\n", "our two servers independently,\n", "e.g. 
directly in a container orchestration service, like\n", "[Kubernetes](https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/),\n", "or in a managed version of the same, like\n", "[Elastic Kubernetes Service](https://aws.amazon.com/eks/),\n", "or with an infrastructure automation tool, like\n", "[Terraform](https://www.terraform.io/)." ] }, { "cell_type": "markdown", "metadata": { "id": "0WI0H6Imcz_h" }, "source": [ "Luckily, there is an easier way, because our model service-plus-UI\n", "combo fits into a common pattern.\n", "\n", "We have a server that we want to be up all the time,\n", "ready to take requests,\n", "but we really only need\n", "the model service to run when a request hits.\n", "\n", "And apart from its environment (which includes the weights),\n", "the model only needs the request in order to produce a result.\n", "\n", "It does not need to hold onto any information in between executions --\n", "it is _stateless_.\n", "\n", "This pattern is common enough that all cloud providers\n", "offer a solution that takes the pain out of scaling\n", "the stateless component:\n", "\"serverless cloud functions\",\n", "so named because\n", "- they are run intermittently, rather than 24/7, like a server.\n", "- they are run on cloud infrastructure.\n", "- they are, as in\n", "[purely functional programming](https://en.wikipedia.org/wiki/Purely_functional_programming)\n", "or in mathematics, \"pure\" functions of their inputs,\n", "with no concept of state." ] }, { "cell_type": "markdown", "metadata": { "id": "eE_FhWxLhhxG" }, "source": [ "We use AWS's serverless offering,\n", "[AWS Lambda](https://aws.amazon.com/lambda/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xw3Una-yJ_mP" }, "outputs": [], "source": [ "from api_serverless import api\n", "\n", "api??" 
] }, { "cell_type": "markdown", "metadata": { "id": "FGAeXmfFiYOi" }, "source": [ "Our main function here, `api.handler`, wraps `ParagraphTextRecognizer.predict`.\n", "\n", "Effectively, `api.handler` maps HTTP requests (`event`s) with AWS's canonical format\n", "to a format our `ParagraphTextRecognizer` understands,\n", "then converts the text recognizer's output into something\n", "that AWS understands.\n", "\n", "Deploying models as web services is an exercise in taking\n", "the Tensor-to-Tensor-mappings we work with in model development\n", "and wrapping them so that they run in the\n", "JSON-to-JSON-mapping world of web services." ] }, { "cell_type": "markdown", "metadata": { "id": "TDMPQKXqr7pS" }, "source": [ "## Talking to a model service" ] }, { "cell_type": "markdown", "metadata": { "id": "V41-UiMct92x" }, "source": [ "Setting up a serverless function on AWS requires an account\n", "(which requires putting down a credit card)\n", "and configuration of permissions\n", "(which is error-prone).\n", "\n", "If you want to see how that process works,\n", "check out our\n", "[\"bonus notebook\" on serverless deployment on AWS Lambda](https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/notebooks/lab99_serverless_aws.ipynb).\n", "Heads up: it uses Docker,\n", "which means it's not compatible with Google Colab.\n", "\n", "So we'll skip that step and,\n", "like Julia Child or Martha Stewart, check out\n", "[one that was prepared earlier](https://tvtropes.org/pmwiki/pmwiki.php/Main/OneIPreparedEarlier).\n", "\n", "The cell below sends a request\n", "to a serverless cloud function running on the FSDL AWS account.\n", "\n", "This request is\n", "much like the one we sent to the API provided by `gradio`,\n", "but we here construct and send it in Python,\n", "using the `requests` library,\n", "rather than operating from the command line.\n", "\n", "When playing around with an API,\n", "writing requests and parsing responses \"by hand\"\n", "in 
the command line is helpful,\n", "but once we're working on real use cases for the API,\n", "we'll want to use higher-level libraries\n", "with good code quality and nice integrations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "76HwEP2Vzz3F" }, "outputs": [], "source": [ "import json\n", "\n", "from IPython.display import Image\n", "import requests # the preferred library for writing HTTP requests in Python\n", "\n", "lambda_url = \"https://3akxma777p53w57mmdika3sflu0fvazm.lambda-url.us-west-1.on.aws/\"\n", "image_url = \"https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png\"\n", "\n", "headers = {\"Content-Type\": \"application/json\"} \n", "payload = json.dumps({\"image_url\": image_url})\n", "\n", "response = requests.post( # we POST the image to the URL, expecting a prediction as a response\n", " lambda_url, data=payload, headers=headers)\n", "pred = response.json()[\"pred\"] # the response is also json\n", "\n", "print(pred)\n", "\n", "Image(url=image_url, width=512)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before deploying a service like this one,\n", "it's important to check how well it handles different traffic volumes and patterns.\n", "This process is known as _load-testing_.\n", "\n", "For a quick tutorial on some basic tooling and a run-through of\n", "results from load-testing the FSDL Text Recognizer on AWS Lambda, see\n", "[this \"bonus notebook\" on load-testing](https://fsdl.me/loadtesting-colab)." 
] }, { "cell_type": "markdown", "metadata": { "id": "bZQ2Dt4URN9o" }, "source": [ "## Local in the front, serverless in the back" ] }, { "cell_type": "markdown", "metadata": { "id": "XMXWTHt4Pxpr" }, "source": [ "The primary \"win\" here\n", "is that we don't need to run\n", "the frontend UI server\n", "and the backend model service in\n", "the same place.\n", "\n", "For example,\n", "we can run a Gradio app locally\n", "but send the images to the serverless function\n", "for prediction.\n", "\n", "Our `app_gradio` implementation supports this via the `PredictorBackend`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4qZ1K0fwOtYK" }, "outputs": [], "source": [ "serverless_backend = app.PredictorBackend(url=lambda_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "5NVVU2JEPSpy" }, "source": [ "Previously, our `PredictorBackend`\n", "was just a wrapper around the `ParagraphTextRecognizer` class.\n", "\n", "By passing a URL,\n", "we switch to sending data elsewhere via an HTTP request.\n", "\n", "This is done by the\n", "`_predict_from_endpoint` method,\n", "which runs effectively the same code we used\n", "to talk to the model service in the cell above." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HtSppJq2O_B_" }, "outputs": [], "source": [ "serverless_backend._predict_from_endpoint??" ] }, { "cell_type": "markdown", "metadata": { "id": "DKA68zxUUO9e" }, "source": [ "The frontend doesn't care where the inference is getting done or how.\n", "\n", "A `gradio.Interface`\n", "just knows there's a Python function that it invokes and then \n", "waits for outputs from.\n", "\n", "Here, that Python function\n", "makes a request to the serverless backend,\n", "rather than running the model.\n", "\n", "Go ahead and try it out!\n", "\n", "You won't notice a difference,\n", "except that the machine you're running this notebook on\n", "no longer runs the model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WEkMzohnOcK0" }, "outputs": [], "source": [ "frontend_serverless_backend = app.make_frontend(serverless_backend.run)\n", "\n", "frontend_serverless_backend.launch(share=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "XytXrIWVuRFu" }, "source": [ "# Serving a `gradio` app with `ngrok`" ] }, { "cell_type": "markdown", "metadata": { "id": "2i64HrL1wa7F" }, "source": [ "We've now got a model service and a web server\n", "that we can stand up and scale independently,\n", "but we're not quite done yet.\n", "\n", "First, our URL is controlled by Gradio.\n", "\n", "Very quickly once we leave the territory of a minimal demo,\n", "we'll want that URL to be branded.\n", "\n", "Relatedly,\n", "you may have noticed messages indicating that the public URL\n", "from Gradio is only good for 72 hours.\n", "\n", "That means we'd have to reset our frontend\n", "and share a new URL every few days." ] }, { "cell_type": "markdown", "metadata": { "id": "clsPvqtJu0V0" }, "source": [ "For projects that are mostly intended as public demos,\n", "you might follow the advice from those printed warnings\n", "and use\n", "[Hugging Face Spaces](https://huggingface.co/docs/hub/spaces)\n", "for free, permanent hosting.\n", "\n", "This relieves you of the burden of keeping the frontend server running.\n", "\n", "However, note that this requires you to use the Hugging Face Hub\n", "as a remote for your `git` repository, alongside GitHub or GitLab.\n", "This connection to the version control system can make for tricky integration,\n", "e.g. 
the need to create a new repository for each new model.\n", "\n", "By default, the demo is embedded inside Hugging Face,\n", "limiting your control over the look and feel.\n", "\n", "However, you can embed the demo in another website with\n", "[Web Components or IFrames](https://gradio.app/sharing_your_app/#embedding-with-web-components).\n", "You can also adapt the aesthetics and interactivity of the demo with\n", "[custom CSS and JS](https://gradio.app/custom_CSS_and_JS/).\n", "\n", "We will instead run the frontend server ourselves\n", "and provide a public URL\n", "without relying on Gradio's service." ] }, { "cell_type": "markdown", "metadata": { "id": "XWxKXSSG0yNX" }, "source": [ "Half of the work is already done for us:\n", "the `gradio` frontend is already listening on a port and IP address\n", "that is accessible locally\n", "(on `127.0.0.1` or `localhost`, as printed below)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ugupgc1bxQlH" }, "outputs": [], "source": [ "frontend_serverless_backend.local_url" ] }, { "cell_type": "markdown", "metadata": { "id": "GWcQa-ks1Ktn" }, "source": [ "So we can, for example, send `curl` requests locally,\n", "i.e. on the same machine as the frontend,\n", "and get responses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z4JRaVjH0kPw" }, "outputs": [], "source": [ "# we send an improperly formatted request, because we just want to check for a response\n", "\n", "!curl -X POST {frontend_serverless_backend.local_url}api/predict" ] }, { "cell_type": "markdown", "metadata": { "id": "ZK4-tPGf32Hf" }, "source": [ "Running the same command on another machine will result in an error --\n", "`127.0.0.1` and `localhost` always mean \"on this machine\"." 
] }, { "cell_type": "markdown", "metadata": { "id": "Eiwa6INa0PGe" }, "source": [ "So fundamentally,\n", "the goal is to take the frontend service\n", "running on an IP and port that is only accessible locally\n", "and make it accessible globally." ] }, { "cell_type": "markdown", "metadata": { "id": "Cuuj13Xk0M0Q" }, "source": [ "There's some tricky bits here --\n", "for example, you'll want to communicate using encryption,\n", "i.e. over HTTPS instead of HTTP --\n", "that make doing this entirely on your own\n", "a bit of a headache.\n", "\n", "To avoid these issues,\n", "we can once again use\n", "[`ngrok`](https://ngrok.com/),\n", "the service we used to provide access to our Label Studio instance\n", "in the data annotation lab.\n", "\n", "The free tier includes public URLs and secure communication with HTTPS.\n", "\n", "However, the URL changes each time you relaunch your service,\n", "e.g. after an outage or a version update.\n", "\n", "The paid tier allows for branded domains,\n", "simpler authentication with\n", "[OAuth](https://oauth.net/),\n", "and some basic scaling tools like load balancing.\n", "\n", "This is what we use for the official FSDL text recognizer at\n", "[fsdl-text-recognizer.ngrok.io](https://fsdl-text-recognizer.ngrok.io/)." ] }, { "cell_type": "markdown", "metadata": { "id": "IoKA_VUr4Gf2" }, "source": [ "To get started, let's\n", "set up our `ngrok` credentials." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3N2jkwdaLZAu" }, "outputs": [], "source": [ "import os\n", "import getpass\n", "\n", "from pyngrok import ngrok\n", "\n", "config_file = ngrok.conf.DEFAULT_NGROK_CONFIG_PATH\n", "config_file_exists = os.path.exists(config_file)\n", "config_file_contents = !cat {config_file}\n", "\n", "auth_token_found = config_file_exists \\\n", " and config_file_contents \\\n", " and \"authtoken\" in config_file_contents[0] \\\n", " and \": exit\" not in config_file_contents # state if interrupted\n", "\n", "if not auth_token_found:\n", " print(\"Enter your ngrok auth token, which can be copied from https://dashboard.ngrok.com/auth\")\n", " !ngrok authtoken {getpass.getpass()}" ] }, { "cell_type": "markdown", "metadata": { "id": "m3SaBJn14YA_" }, "source": [ "From there,\n", "it's as simple as pointing\n", "an `ngrok` tunnel\n", "at the port associated with your frontend.\n", "\n", "> For our purposes, ports are\n", "\"places you can listen for messages to your web service\".\n", "By separating ports,\n", "which are identifiers within a machine,\n", "from URLs/IPs,\n", "which are identifiers across machines,\n", "we can run multiple services on a single machine." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wURZiaA5LkeF" }, "outputs": [], "source": [ "TEXT_RECOGNIZER_PORT = frontend_serverless_backend.server_port\n", "\n", "https_tunnel = ngrok.connect(TEXT_RECOGNIZER_PORT, bind_tls=True)\n", "print(https_tunnel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Head to the printed `ngrok.io` URL from any device --\n", "e.g. a mobile phone --\n", "to check out your shiny new ML-powered application UI\n", "with serverless backend." 
] }, { "cell_type": "markdown", "metadata": { "id": "XWYBGHLs5iwN" }, "source": [ "Running a web service out of a Jupyter notebook is not recommended.\n", "\n", "`gradio` and `ngrok`\n", "can be run from the command line.\n", "\n", "If you're running the lab locally,\n", "just define the `TEXT_RECOGNIZER_PORT`\n", "and `LAMBDA_URL` environment variables\n", "and then run\n", "\n", "```bash\n", "python app_gradio/app.py --model_url $LAMBDA_URL --model_port $TEXT_RECOGNIZER_PORT\n", "```\n", "\n", "in one terminal\n", "and, in a separate terminal,\n", "run\n", "```bash\n", "ngrok http $TEXT_RECOGNIZER_PORT\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "nycSygTy-PcQ" }, "source": [ "and navigate to the printed URL." ] }, { "cell_type": "markdown", "metadata": { "id": "oQCpzYzHRGfd" }, "source": [ "## Launching a server on a cloud instance" ] }, { "cell_type": "markdown", "metadata": { "id": "RKKnzQjmQPV8" }, "source": [ "We are almost, but not quite,\n", "at the point of a reasonably professional web service.\n", "\n", "The last missing piece is that our server is running\n", "either on Colab,\n", "which has short uptimes and is not intended for serving,\n", "or on our own personal machine,\n", "which is also likely a few\n", "[nines](https://en.wikipedia.org/wiki/High_availability#Percentage_calculation) short of an uptime SLA." ] }, { "cell_type": "markdown", "metadata": { "id": "IKOuYfpTQR-c" }, "source": [ "We want to instead run this on a dedicated server,\n", "and the simplest way to do so is to spin up a machine with a cloud provider.\n", "\n", "[Elastic Compute Cloud](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/concepts.html)\n", "(aka EC2)\n", "is the option in AWS,\n", "our chosen cloud provider."
] }, { "cell_type": "markdown", "metadata": { "id": "15NI6gI1746O" }, "source": [ "To get the server going on another machine,\n", "we'll need to `git clone` our library,\n", "`pip install` our `prod` requirements,\n", "and then finally run `ngrok` and `app_gradio/app.py`." ] }, { "cell_type": "markdown", "metadata": { "id": "faStq6aV-hci" }, "source": [ "We can make that process slightly easier\n", "by incorporating it into a `Dockerfile`\n", "and building a container image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_1i0M7hR-moU" }, "outputs": [], "source": [ "!cat app_gradio/Dockerfile" ] }, { "cell_type": "markdown", "metadata": { "id": "jskTeGs9AroE" }, "source": [ "We can then store the container image in a registry, like\n", "[Docker Hub](https://hub.docker.com/)\n", "or the container image registry built into our cloud provider, like AWS's\n", "[Elastic Container Registry](https://aws.amazon.com/ecr/).\n", "\n", "Then, setup just means pulling the image down onto the machine\n", "we want to run our server from and executing a `docker run` command." 
] } ], "metadata": { "colab": { "collapsed_sections": [], "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 1 } ================================================ FILE: lab08/notebooks/lab08_monitoring.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "7yQQTA9IGDt8" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "MX9n-Zed8G_T" }, "source": [ "# Lab 08: Monitoring" ] }, { "cell_type": "markdown", "metadata": { "id": "tv8O0V0EV09z" }, "source": [ "## What You Will Learn\n", "\n", "- How to add user feedback and model monitoring to a Gradio-based app\n", "- How to analyze this logged information to uncover and debug model issues\n", "- Just how large the gap between benchmark data and data from users can be, and what to do about it" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "45D6GuSwvT7d" }, "outputs": [], "source": [ "lab_idx = 8\n", "\n", "\n", "if \"bootstrap\" not in locals() or bootstrap.run:\n", " # path management for Python\n", " pythonpath, = !echo $PYTHONPATH\n", " if \".\" not in pythonpath.split(\":\"):\n", " pythonpath = \".:\" + pythonpath\n", " %env PYTHONPATH={pythonpath}\n", " !echo $PYTHONPATH\n", "\n", " # get both Colab and local notebooks into the same state\n", " !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py\n", " import bootstrap\n", "\n", " %matplotlib inline\n", "\n", " # change into the lab directory\n", " 
bootstrap.change_to_lab_dir(lab_idx=lab_idx)\n", "\n", " bootstrap.run = False # change to True to re-run setup\n", "\n", "!pwd\n", "%ls" ] }, { "cell_type": "markdown", "metadata": { "id": "cUdTJE54V09z" }, "source": [ "### Follow along with a video walkthrough on YouTube:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4J9hDxNsV09z" }, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "\n", "IFrame(src=\"https://fsdl.me/2022-lab-08-video-embed\", width=\"100%\", height=720)" ] }, { "cell_type": "markdown", "metadata": { "id": "Zvi49122ho0r" }, "source": [ "# Basic user feedback with `gradio`" ] }, { "cell_type": "markdown", "metadata": { "id": "56y2r9IYkY7A" }, "source": [ "On top of the basic health checks and event logging\n", "necessary for any distributed system\n", "(provided for our application by\n", "[AWS CloudWatch](https://aws.amazon.com/cloudwatch/),\n", "which collects logs from EC2 and Lambda instances),\n", "ML-powered applications need specialized monitoring solutions.\n", "\n", "In particular, we want to give users a way\n", "to report issues or indicate their level of satisfaction\n", "with the model.\n", "\n", "The UI-building framework we're using, `gradio`,\n", "comes with built-in support for user feedback, under the name \"flagging\"." ] }, { "cell_type": "markdown", "metadata": { "id": "wXq4jcjCkNap" }, "source": [ "To see how this works, we first spin up our front end,\n", "pointed at the AWS Lambda backend,\n", "as in\n", "[the previous lab](https://fsdl.me/lab07-colab)."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rAZrYRnSiMER" }, "outputs": [], "source": [ "from app_gradio import app\n", "\n", "\n", "lambda_url = \"https://3akxma777p53w57mmdika3sflu0fvazm.lambda-url.us-west-1.on.aws/\"\n", "\n", "backend = app.PredictorBackend(url=lambda_url)" ] }, { "cell_type": "markdown", "metadata": { "id": "STXn1XaHkU42" }, "source": [ "Adding user feedback collection\n", "is as easy as passing `flagging=True`.\n", "\n", "> Here, the `flagging` argument is passed to\n", "`app.make_frontend`, a function from the FSDL codebase.\n", "If you use the `gradio.Interface` class directly,\n", "flagging is controlled by its `allow_flagging` argument.\n", "In between, our code adds a bit of extra logic\n", "so that we can support\n", "multiple storage backends for logging flagged data.\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "mxZQRklXV091" }, "source": [ "Run the cell below to create a frontend\n", "(accessible on a public Gradio URL and inside the notebook)\n", "and observe the new \"flagging\" buttons underneath the outputs." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Kgygx8d5ip9V" }, "outputs": [], "source": [ "frontend = app.make_frontend(fn=backend.run, flagging=True)\n", "frontend.launch(share=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "zV2tu8HTk242" }, "source": [ "Click one of the buttons to trigger flagging.\n", "\n", "It doesn't need to be a legitimate issue with the model's outputs.\n", "\n", "Instead of just submitting one of the example images,\n", "you might additionally use the image editor\n", "(pencil button on uploaded images)\n", "to crop it."
] }, { "cell_type": "markdown", "metadata": { "id": "gJV79PDIk-4S" }, "source": [ "Flagged data is stored on the server's local filesystem,\n", "by default in the `flagged/` directory\n", "as a `.csv` file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RbCcCxvHi2jh" }, "outputs": [], "source": [ "!ls flagged" ] }, { "cell_type": "markdown", "metadata": { "id": "Koh1SP9NlA6y" }, "source": [ "We can load the `.csv` with `pandas`,\n", "the Python library for handling tabular data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OJCnIsfEjC05" }, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "\n", "\n", "log_path = Path(\"flagged\") / \"log.csv\"\n", "\n", "flagged_df = None\n", "if log_path.exists():\n", " flagged_df = pd.read_csv(log_path, quotechar=\"'\") # quoting can be painful for natural text data\n", " flagged_df = flagged_df.dropna(subset=[\"Handwritten Text\"]) # drop any flags without an image\n", "\n", "flagged_df" ] }, { "cell_type": "markdown", "metadata": { "id": "KZieT-FgldKa" }, "source": [ "Notice that richer data, like images, is stored with references --\n", "here, the names of local files.\n", "\n", "This is a common pattern:\n", "binary data doesn't go in the database,\n", "only pointers to binary data.\n", "\n", "We can then read the data back to analyze our model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gWG3T3Qql_99" }, "outputs": [], "source": [ "from IPython.display import display\n", "\n", "from text_recognizer.util import read_image_pil\n", "\n", "\n", "if flagged_df is not None:\n", " row = flagged_df.iloc[-1]\n", " print(row[\"output\"])\n", " display(read_image_pil(Path(\"flagged\") / row[\"Handwritten Text\"]))" ] }, { "cell_type": "markdown", "metadata": { "id": "0gIpfRMFl9_D" }, "source": [ "We encourage you to play around with the model for a bit,\n", "uploading your own images.\n", "\n", "This is an important step in understanding your model\n", "and your domain --\n", "especially when you're familiar with the data types involved.\n", "\n", "But even when you are,\n", "we expect you'll quickly find\n", "that you run out of ideas\n", "for different ways to probe your model.\n", "\n", "To really learn more about your model,\n", "you'll need some actual users.\n", "\n", "In small projects,\n", "these can be other team members who are less enmeshed\n", "in the details of model development and data munging.\n", "\n", "But to create something that can appeal to a broader set of users,\n", "you'll want to collect feedback from your potential userbase." 
] }, { "cell_type": "markdown", "metadata": { "id": "RHArpXNyRtg7" }, "source": [ "# Debugging production models with `gantry`" ] }, { "cell_type": "markdown", "metadata": { "id": "hbGCYG0BmvdE" }, "source": [ "Unfortunately, this aspect of model development\n", "is particularly challenging to replicate in\n", "a course setting, especially a MOOC --\n", "where do these users come from?\n", "\n", "As part of the 2022 edition of the course, we've\n", "[been running a text recognizer application](https://fsdl-text-recognizer.ngrok.io)\n", "and collecting user feedback on it.\n", "\n", "Rather than saving user feedback data locally,\n", "as with the CSV logger above,\n", "we've been sending that data to\n", "[Gantry](https://gantry.io/),\n", "a model monitoring and continual learning tool.\n", "\n", "That's because local logging is a very bad idea:\n", "as logs grow, the storage needs and read/write time grow,\n", "which unduly burdens the frontend server.\n", "\n", "The `gradio` library supports logging of user-flagged data\n", "to arbitrary backends via\n", "`FlaggingCallback`s.\n", "\n", "So there's some new elements to the codebase:\n", "most importantly here, a `GantryImageToTextLogger`\n", "that inherits from `gradio.FlaggingCallback`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pptT76DWmlB0" }, "outputs": [], "source": [ "from app_gradio import flagging\n", "\n", "\n", "print(flagging.GantryImageToTextLogger.__init__.__doc__)" ] }, { "cell_type": "markdown", "metadata": { "id": "-3HevRM2YkbZ" }, "source": [ "If we add this `Callback` to our setup --\n", "and add a Gantry API key to our environment --\n", "then we can start sending data to Gantry's service." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UHnIV0e_a9o6" }, "outputs": [], "source": [ "app.make_frontend??" 
] }, { "cell_type": "markdown", "metadata": { "id": "jJcfaWNpRzJF" }, "source": [ "The short version of how the logging works:\n", "we upload flagged images to S3 for storage (`GantryImageToTextLogger._to_s3`)\n", "and send the URL to Gantry along with the outputs (`GantryImageToTextLogger._to_gantry`)." ] }, { "cell_type": "markdown", "metadata": { "id": "uviSZDTma1RT" }, "source": [ "Below, we'll download that data\n", "and look through it in the notebook,\n", "using typical Python data analysis tools,\n", "like `pandas` and `seaborn`.\n", "\n", "By analogy to\n", "[EDA](https://en.wikipedia.org/wiki/Exploratory_data_analysis),\n", "consider this an \"exploratory model analysis\"." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LFxypmESXESL" }, "outputs": [], "source": [ "import gantry.query as gq\n", "\n", "\n", "read_only_key = \"VpPfHPDSk9e9KKAgbiHBh7mqF_8\"\n", "gq.init(api_key=read_only_key)\n", "\n", "gdf = gq.query( # we query Gantry's service with the following parameters:\n", " application=\"fsdl-text-recognizer\", # which tracked application should we draw from?\n", " # what time period should we pull data from? 
here, the first two months the app was up\n", " start_time=\"2022-07-01T07:00:00.000Z\",\n", " end_time=\"2022-09-01T06:59:00.000Z\",\n", ")\n", "\n", "raw_df = gdf.fetch()\n", "df = raw_df.dropna(axis=\"columns\", how=\"all\") # remove any irrelevant columns\n", "df = df[df[\"tags.env\"] == \"dev\"] # filter down to info logged from the development environment\n", "print(\"number of rows:\", len(df))\n", "df = df.drop_duplicates(keep=\"first\", subset=\"inputs.image\") # remove repeated reports, e.g. of example images\n", "print(\"number of unique rows:\", len(df))\n", "\n", "print(\"\\ncolumns:\")\n", "df.columns" ] }, { "cell_type": "markdown", "metadata": { "id": "bN6YNmnCV094" }, "source": [ "We'll walk through what each of these columns means,\n", "but the three most important are the ones we logged directly from the application:\n", "`feedback.flag`, `inputs.image`, and `outputs.output_text`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c8SEwiAXV094" }, "outputs": [], "source": [ "main_columns = [column for column in df.columns if \"(\" not in column] # derived columns have a \"function call\" in the name\n", "main_columns" ] }, { "cell_type": "markdown", "metadata": { "id": "i8HfH-BIV094" }, "source": [ "If you're interested in playing\n", "around with the data yourself\n", "in Gantry's UI,\n", "as we do in the\n", "[video walkthrough for the lab](https://fsdl.me/2022-lab-08-video),\n", "you'll need a Gantry account.\n", "\n", "Gantry is currently in closed beta.\n", "Unlike training or experiment management,\n", "model monitoring and continual learning\n", "are at the frontier of applied ML,\n", "so tooling is just starting to roll out.\n", "\n", "FSDL students are invited to this beta and\n", "[can create a \"read-only\" account here](https://gantry.io/fsdl-signup)\n", "so they can view the data in the UI\n", "and explore it themselves.\n", "\n", "As an early startup,\n", "Gantry is very interested in feedback\n", "from practitioners!\n", "So 
if you do try out the Gantry UI,\n", "send any impressions, bug reports, or ideas to\n", "`support@gantry.io`\n", "\n", "This is also a chance for you\n", "to influence the development\n", "of a new tool that could one day\n", "end up at the center of continual learning\n", "workflows --\n", "as when\n", "[FSDL students in spring 2019 got a chance to be early users of W&B](https://www.youtube.com/watch?t=1468&v=Eiz1zcqrqw0&feature=youtu.be&ab_channel=FullStackDeepLearning)." ] }, { "cell_type": "markdown", "metadata": { "id": "RmTFHvxHi4el" }, "source": [ "## Basic stats and behavioral monitoring" ] }, { "cell_type": "markdown", "metadata": { "id": "hYSQ0r7eV094" }, "source": [ "We start by just getting some basic statistics.\n", "\n", "For example, we can get descriptive statistics for\n", "the information we've logged." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fb3BMn7gfQRI" }, "outputs": [], "source": [ "df[\"feedback.flag\"].describe()" ] }, { "cell_type": "markdown", "metadata": { "id": "T9OseYhc1Q8i" }, "source": [ "Note that the format we're working with is the `pandas.DataFrame` --\n", "a standard format for tables in Python.\n", "\n", "`pandas` can be\n", "[very tricky](https://github.com/chiphuyen/just-pandas-things).\n", "\n", "It's not so bad when doing exploratory analysis like this,\n", "but take care when using it in production settings!\n", "\n", "If you'd like to learn more `pandas`,\n", "[Brandon Rhodes's `pandas` tutorial from PyCon 2015](https://www.youtube.com/watch?v=5JnMutdy6Fw&ab_channel=PyCon2015)\n", "is still one of the best introductions,\n", "even though it's nearly a decade old." ] }, { "cell_type": "markdown", "metadata": { "id": "eG15SMkgV095" }, "source": [ "`pandas` objects support sampling with `.sample`,\n", "which is useful for quick \"spot-checking\" of data." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FZ5BRRqjc1Of" }, "outputs": [], "source": [ "df[\"feedback.flag\"].sample(10)" ] }, { "cell_type": "markdown", "metadata": { "id": "w3rZaYwSzu-D" }, "source": [ "Unlike many other kinds of applications,\n", "toxic and offensive behavior is\n", "one of the most critical potential issues with\n", "many ML models,\n", "from\n", "[generative models like GPT-3](https://www.middlebury.edu/institute/sites/www.middlebury.edu.institute/files/2020-09/gpt3-article.pdf)\n", "to even humble\n", "[image labeling models](https://archive.nytimes.com/bits.blogs.nytimes.com/2015/07/01/google-photos-mistakenly-labels-black-people-gorillas/).\n", "\n", "So ML models, especially when newly deployed\n", "or when encountering new user bases,\n", "need careful supervision." ] }, { "cell_type": "markdown", "metadata": { "id": "-CbdSz0hzze7" }, "source": [ "We use a\n", "[Gantry tool called Projections](https://docs.gantry.io/en/stable/guides/projections.html)\n", "to apply the NLP models from the\n", "[`detoxify` suite](https://github.com/unitaryai/detoxify),\n", "which score text for features like obscenity and identity attacks,\n", "to our model's outputs." ] }, { "cell_type": "markdown", "metadata": { "id": "1Z4lsgRcpQql" }, "source": [ "To get a quick plot of the resulting values,\n", "we can use the `pandas` built-in interface\n", "to `matplotlib`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9UbBg947fAsh" }, "outputs": [], "source": [ "df.plot(y=\"detoxify.obscene(outputs.output_text)\", kind=\"hist\");" ] }, { "cell_type": "markdown", "metadata": { "id": "qxiIXGf0pVd5" }, "source": [ "Without context, this chart isn't super useful --\n", "is a score of `obscene=0.12` bad?\n", "\n", "We need a baseline!" 
] }, { "cell_type": "markdown", "metadata": { "id": "UbOeOkzQgBDE" }, "source": [ "Once the model is stable in production,\n", "we can compare the values across time --\n", "grouping or filtering production data by timestamp.\n", "\n", "For this first version of the model,\n", "we instead compare the production results with the results on the test data,\n", "which was also ingested with `gantry`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ooa-Al48f_au" }, "outputs": [], "source": [ "test_df = raw_df.dropna(axis=\"columns\", how=\"all\") # remove any irrelevant columns\n", "test_df = test_df[test_df[\"tags.env\"] == \"test\"].copy() # filter down to info logged from the test environment; copy so we can add columns later\n", "\n", "test_df.sample(10) # show a sample" ] }, { "cell_type": "markdown", "metadata": { "id": "TssF7sSX1Q8k" }, "source": [ "To compare the two `DataFrame`s,\n", "we `concat`enate them together\n", "and add in some metadata\n", "identifying where the observations came from.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oXWqfOdfgi4o" }, "outputs": [], "source": [ "test_df[\"environment\"] = \"test\"\n", "df[\"environment\"] = \"prod\"\n", "\n", "comparison_df = pd.concat([df, test_df])" ] }, { "cell_type": "markdown", "metadata": { "id": "5fp9gAX_V09_" }, "source": [ "From there, we can use grouping to calculate statistics of interest:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NIGBxyZIV09_" }, "outputs": [], "source": [ "stats = comparison_df.groupby(\"environment\").describe()\n", "\n", "stats[\"detoxify.obscene(outputs.output_text)\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "2G2tVhhY1Q8k" }, "source": [ "These descriptive statistics are helpful,\n", "but as with our simple plot above,\n", "we want to _look_ at the data.\n", "\n", "Exploratory data analysis is typically very visual --\n", "the goal is to find phenomena so obvious\n", "that statistical testing is an afterthought --\n",
"and so is exploratory model analysis.\n", "\n", "`matplotlib` is based on plotting arrays,\n", "rather than `DataFrame`s or other tabular data,\n", "so it's not a great fit on its own here,\n", "unless we want to tolerate a lot of boilerplate.\n", "\n", "`pandas` has basic built-in plotting\n", "that interfaces with `matplotlib`,\n", "but it's not that ergonomic for comparisons or flexible\n", "without just dropping back to matplotlib.\n", "\n", "There are a number of other Python plotting libraries,\n", "many with an emphasis on share-ability and interaction\n", "([Vega-Altair](https://altair-viz.github.io/),\n", "[`bokeh`](http://bokeh.org/),\n", "and\n", "[Plotly](https://plotly.com/),\n", "to name a few)\n", "and others with an emphasis on usability\n", "(e.g. [`ggplot`](https://realpython.com/ggplot-python/)).\n", "\n", "The one that we like for in-notebook analysis\n", "that balances ease of use\n", "on tabular data with flexibility is\n", "[`seaborn`](https://seaborn.pydata.org/)." 
] }, { "cell_type": "markdown", "metadata": { "id": "7nZV8uoG1Q8k" }, "source": [ "Comparing the distributions of the `detoxify.obscene` metric\n", "is a single function call:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WnGxCz1f1Q8k" }, "outputs": [], "source": [ "import seaborn as sns\n", "\n", "\n", "sns.displot( # plot the dis-tribution\n", " data=comparison_df, # of data from this df\n", " # specifically, this column, along the x-axis\n", " x=\"detoxify.obscene(outputs.output_text)\",\n", " # and split it up (in color/hue) by this column\n", " hue=\"environment\"\n", ");" ] }, { "cell_type": "markdown", "metadata": { "id": "jO6FuRCQV0-A" }, "source": [ "We can quickly see that the obscenity scores according to `detoxify`\n", "are generally lower in our `prod`uction environment,\n", "so we don't have a reason to suspect\n", "our model is behaving too badly in production\n", "-- though see the exercises for more on this!\n", "\n", "We can see the same thing\n", "without having to write query, cleaning, and plotting code\n", "[in the Gantry UI here](https://app.gantry.io/applications/fsdl-text-recognizer/distribution?view=2022-class&compare=test-ingest) --\n", "note that viewing the dashboard requires a Gantry account,\n", "which you can sign up for\n", "[here](https://gantry.io/fsdl-signup)." ] }, { "cell_type": "markdown", "metadata": { "id": "iKZ0l2MCjlDn" }, "source": [ "## Debugging the Text Recognizer" ] }, { "cell_type": "markdown", "metadata": { "id": "ovp8fZ1GpUet" }, "source": [ "In our application,\n", "we don't have user corrections or labels from annotators,\n", "so we can't calculate an accuracy, a loss, or a character error rate.\n", "\n", "We instead look for signals that are correlated with\n", "those values.\n", "\n", "This approach has limits\n", "(see, e.g. 
the analysis in the\n", "[MLDeMon paper](https://arxiv.org/abs/2104.13621))\n", "and setting alerts or test failures on things that are only correlated with,\n", "rather than directly caused by, poor performance is a bad idea.\n", "\n", "But it's very useful to have this information logged\n", "to catch large errors at a glance\n", "or to provide tools for slicing, filtering, and grouping data\n", "while doing exploratory model analysis or debugging." ] }, { "cell_type": "markdown", "metadata": { "id": "0YauDrY51Q8l" }, "source": [ "We can also compute these signals with Gantry Projections.\n", "\n", "Low entropy (e.g. repetition) is a failure mode of language models,\n", "as is excessively high entropy (e.g. uniformly random text).\n", "\n", "We can review the output text entropy distributions in\n", "production and during testing\n", "by plotting them against one another\n", "(here or\n", "[in the Gantry UI](https://app.gantry.io/applications/fsdl-text-recognizer/distribution?view=2022-class&compare=test-ingest))." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "czepR9o7l2FO" }, "outputs": [], "source": [ "sns.displot(\n", " data=comparison_df,\n", " x=\"text_stats.basics.entropy(outputs.output_text)\",\n", " hue=\"environment\"\n", ");" ] }, { "cell_type": "markdown", "metadata": { "id": "8LiFvkoR1Q8l" }, "source": [ "It appears there are more low-entropy strings in the model's outputs in production.\n", "\n", "With models that operate on human-relevant data,\n", "like text and images,\n", "it's important to look at the raw data,\n", "not just projections.\n", "\n", "Let's take a look at a sample of outputs from the model running on test data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FQ9kTz2ZmOwR" }, "outputs": [], "source": [ "test_df[\"outputs.output_text\"].sample(10)" ] }, { "cell_type": "markdown", "metadata": { "id": "BpZ_35uD1Q8l" }, "source": [ "The results are not incredible, but they are recognizably \"English with typos\"." ] }, { "cell_type": "markdown", "metadata": { "id": "NVlj3vYf1Q8l" }, "source": [ "Let's look specifically at low entropy examples from production\n", "(we can also view this\n", "[filtered data in the Gantry UI](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=2022-class-low-entropy&compare=test-ingest))." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p0dkx1VzoJ9C" }, "outputs": [], "source": [ "df.loc[df[\"text_stats.basics.entropy(outputs.output_text)\"] < 5][\"outputs.output_text\"].sample(10)" ] }, { "cell_type": "markdown", "metadata": { "id": "iMmcPuynV0-C" }, "source": [ "Yikes! Lots of repetitive gibberish." ] }, { "cell_type": "markdown", "metadata": { "id": "stStBoCZ1Q8m" }, "source": [ "Knowing the outputs are bad,\n", "there are two culprits:\n", "the input-output mapping (aka the model)\n", "or the inputs." 
] }, { "cell_type": "markdown", "metadata": { "id": "nFaGYnjcmKf6" }, "source": [ "We ran the same model in a similar environment\n", "to get those outputs,\n", "so it's most likely due to some difference in the inputs.\n", "\n", "Let's check them!\n", "\n", "We added Gantry Projections to look at the distribution of pixel values as well." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uSwnexFRlaIV" }, "outputs": [], "source": [ "sns.displot(\n", " data=comparison_df,\n", " x=\"image.greyscale_image_mean(inputs.image)\",\n", " hue=\"environment\"\n", ");" ] }, { "cell_type": "markdown", "metadata": { "id": "iqkWkM45yMgV" }, "source": [ "There's a huge difference in mean pixel values --\n", "almost all images have mean intensities that are very dark in the testing environment,\n", "but we see both dark and light images in production.\n", "\n", "Reviewing the\n", "[raw data in Gantry](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=2022-class-low-entropy&compare=test-ingest)\n", "confirms that we are getting images with very different brightnesses in production\n", "and whiffing the predictions\n", "-- along with images that reveal a number of other interesting failure modes." ] }, { "cell_type": "markdown", "metadata": { "id": "X5uWeR6n1Q8m" }, "source": [ "To take a look locally,\n", "we'll need to pull the images down from S3,\n", "where they are stored." ] }, { "cell_type": "markdown", "metadata": { "id": "NbNMlevz1Q8m" }, "source": [ "The cell below defines a quick utility for\n", "reading from S3 without authentication.\n", "\n", "It is based on the `smart_open` and `boto3` libraries,\n", "which we briefly saw in the\n", "[model deployment lab](https://fsdl.me/lab07-colab)\n", "and the\n", "[data annotation lab](https://fsdl.me/lab06-colab)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-FNIm0MOovtu" }, "outputs": [], "source": [ "import boto3\n", "from botocore import UNSIGNED\n", "from botocore.config import Config\n", "import smart_open\n", "\n", "from text_recognizer.util import read_image_pil_file\n", "\n", "# spin up a client for communicating with s3 without authenticating (\"UNSIGNED\" activity)\n", "s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))\n", "unsigned_params = {\"client\": s3}\n", "\n", "def read_image_unsigned(image_uri, grayscale=False):\n", " with smart_open.open(image_uri, \"rb\", transport_params=unsigned_params) as image_file:\n", " return read_image_pil_file(image_file, grayscale)" ] }, { "cell_type": "markdown", "metadata": { "id": "SxBpmPYrV0-F" }, "source": [ "Run the cell below to repeatedly sample a random input/output pair\n", "flagged in production." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xy90rzcWobuk" }, "outputs": [], "source": [ "row = df.sample().iloc[0]\n", "print(\"image url:\", row[\"inputs.image\"])\n", "print(\"prediction:\", row[\"outputs.output_text\"])\n", "read_image_unsigned(row[\"inputs.image\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "oFdT2W2xtOGx" }, "source": [ "### Take-aways for developing models\n", "\n", "The most immediate take-away from reviewing just a few examples is that\n", "user data is way more heterogeneous than train/val/test data!\n", "\n", "This is a\n", "[fairly](https://browsee.io/blog/a-guide-to-session-replays-for-product-managers/)\n", "[universal](https://medium.com/@beasles/edge-case-responsive-design-9b610138ddbd)\n", "[finding](https://quoteinvestigator.com/2021/05/04/no-plan/).\n", "\n", "Let's also consider some specific failure modes in our case\n", "and how we might resolve them:\n", "\n", "- Failure mode: Users mostly provide images with dark text on light background, but we train on dark background.\n", " - Resolution: We could
check image brightness and flip if needed,\n", " but this feels like a cop-out -- most text is dark on a light background!\n", " - Resolution: We add image brightness inversion to our train-time augmentations.\n", "- Failure mode: Users expect our \"handwritten text recognition\" tool to work with printed and digital text.\n", " - Resolution: We could try better sign-posting and user education,\n", " but this is also something of a cop-out.\n", " Users expect the tool to work on all text,\n", " so we shouldn't violate that expectation.\n", " - Resolution: We synthesize digital text data --\n", " text rendering is a feature of just about any mature programming language.\n", "- Failure mode: Users provide text on heterogeneous backgrounds\n", " - Resolution: We collect or synthesize more heterogeneous data,\n", " e.g. placing text (with or without background coloring)\n", " on top of random image backgrounds.\n", "- Failure mode: Users provide text with characters and symbols outside of our dictionary.\n", " - Resolution: We can expand the model outputs and collect more heterogeneous data\n", "- Failure mode: Users provide images with multiple blocks of text\n", " - Resolution: We develop an architecture/task definition that can handle multiple regions.\n", " We'll need to collect and/or synthesize data to support" ] }, { "cell_type": "markdown", "metadata": { "id": "9rQH6zI8u7WN" }, "source": [ "Notice: these are almost entirely changes to data,\n", "and most of them involve collecting more or synthesizing it.\n", "\n", "This is very much typical!\n", "\n", "Data drives improvements to models,\n", "[even at scale](https://www.lesswrong.com/posts/6Fpvch8RR29qLEWNH/chinchilla-s-wild-implications)." 
] }, { "cell_type": "markdown", "metadata": { "id": "2P5MrIW5V0-F" }, "source": [ "### Take-aways for exploratory model analysis" ] }, { "cell_type": "markdown", "metadata": { "id": "mfMf1wwR1Q8n" }, "source": [ "Notice that we had to write a lot of code,\n", "which was developed and which we ran in a\n", "tight interactive loop.\n", "\n", "This type of code is very hard to turn into scripts --\n", "how do you trigger an alert on a plot? --\n", "which makes it brittle and hard to version and share.\n", "\n", "It's also based on possibly very large-scale data artifacts.\n", "\n", "The right tool for this job is a UI\n", "on top of a database.\n", "\n", "In the\n", "[video walkthrough for this lab](https://fsdl.me/2022-lab-08-video),\n", "we do effectively the same analysis,\n", "but inside Gantry,\n", "which makes the process more fluid.\n", "\n", "Gantry is still in closed beta,\n", "but if you're interested in applying it to your own applications, you can\n", "[join the waitlist](https://gantry.io/waitlist/)." ] }, { "cell_type": "markdown", "metadata": { "id": "M73gui0XhgCl" }, "source": [ "# Exercises" ] }, { "cell_type": "markdown", "metadata": { "id": "mWWrmGiThhMw" }, "source": [ "### 🌟 Examine the test data strings, both output and ground truth."
] }, { "cell_type": "markdown", "metadata": { "id": "km0nv0Mghmd_" }, "source": [ "We compared our production obscenity metric to the test-time values of that same metric\n", "and determined that we had not gotten worse,\n", "so things were fine.\n", "\n", "But what if the test-time baseline is bad?\n", "\n", "Review the raw test ground truth data\n", "[here](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=test-ingest),\n", "if you\n", "[signed up for a Gantry account](https://gantry.io/fsdl-signup),\n", "or by looking at the contents of `test_df` above.\n", "\n", "Sort by `detoxify.identity_attack(feedback.ground_truth_string)`\n", "or filter to only high values of that metric.\n", "\n", "Review the example `feedback.ground_truth_string` texts and consider:\n", "is this the subset of English\n", "we want the model to be training on?\n", "what objections might be raised to the contents?\n", "\n", "You might also look for cases where the `detoxify` models misunderstood meaning --\n", "e.g. an innocuous use of a word that's often used objectionably." ] }, { "cell_type": "markdown", "metadata": { "id": "1Q6mWRwS1Q8t" }, "source": [ "### 🌟🌟 Start building \"regression testing suites\" by doing error analysis on these examples."
] }, { "cell_type": "markdown", "metadata": { "id": "jfsCnjCg1Q8t" }, "source": [ "Do this by going through feedback data one image/text pair at a time --\n", "[in Gantry](https://app.gantry.io/applications/fsdl-text-recognizer/data?view=2022-class-low-entrop)\n", "or inside this notebook.\n", "\n", "Start by just taking notes on each example\n", "(anywhere -- Google Sheets/Excel/Notion, or just a sheet of paper).\n", "\n", "The primary question you should ask is:\n", "how does this example differ from the data shown in training?\n", "\n", "Check\n", "[this W&B Artifact page](https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/artifacts/run_table/run-1vrnrd8p-trainpredictions/v194/files/train/predictions.table.json#f5854c9c18f6c24a4e99)\n", "to see what training data\n", "(including augmentation)\n", "looks like.\n", "\n", "Once you have some notes,\n", "try and formalize them into a small number of \"failure modes\" --\n", "you can choose to align them with the failure modes described in the section\n", "on take-aways for model development or not.\n", "\n", "If you want to finish the loop,\n", "you might set up Label Studio, as in\n", "[the data annotation lab](https://fsdl.me/lab06-colab).\n", "An annotator should add at least a\n", "\"label\" that gives the type of issue\n", "and perhaps also add a text annotation\n", "while they are at it." 
] } ], "metadata": { "colab": { "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: lab08/tasks/lint.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false # apply automatic formatting echo "black" pre-commit run black || FAILURE=true # check for python code style violations, see .flake8 for details echo "flake8" pre-commit run flake8 || FAILURE=true # check for shell scripting style violations and common bugs echo "shellcheck" pre-commit run shellcheck || FAILURE=true # check python types echo "mypy" pre-commit run mypy || FAILURE=true if [ "$FAILURE" = true ]; then echo "Linting failed" exit 1 fi echo "Linting passed" exit 0 ================================================ FILE: lab08/text_recognizer/__init__.py ================================================ """Modules for creating and running a text recognizer.""" ================================================ FILE: lab08/text_recognizer/callbacks/__init__.py ================================================ from .model import ModelSizeLogger from .optim import LearningRateMonitor from . 
import imtotext from .imtotext import ImageToTextTableLogger as ImageToTextLogger ================================================ FILE: lab08/text_recognizer/callbacks/imtotext.py ================================================ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only try: import wandb has_wandb = True except ImportError: has_wandb = False from .util import check_and_warn class ImageToTextTableLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.on_train: if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_table", "image-to-text table"): return else: self._log_image_text_table(trainer, output, batch, "validation/predictions") def _log_image_text_table(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = xs[:mx], gt_strs[:mx], pred_strs[:mx] xs = [wandb.Image(x) for x in xs] rows = zip(*[xs, gt_strs, pred_strs]) columns = ["input_image", "ground_truth_string", "predicted_string"] trainer.logger.log_table(key=key, columns=columns, data=list(rows)) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) class ImageToTextCaptionLogger(pl.Callback): """Logs the inputs and outputs of an image-to-text model to Weights & 
Biases.""" def __init__(self, max_images_to_log=32, on_train=True): super().__init__() self.max_images_to_log = min(max(max_images_to_log, 1), 32) self.on_train = on_train self._required_keys = ["gt_strs", "pred_strs"] @rank_zero_only def on_train_batch_end(self, trainer, module, output, batch, batch_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "train/predictions") @rank_zero_only def on_validation_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "validation/predictions") @rank_zero_only def on_test_batch_end(self, trainer, module, output, batch, batch_idx, dataloader_idx): if self.has_metrics(output): if check_and_warn(trainer.logger, "log_image", "image-to-text"): return else: self._log_image_text_caption(trainer, output, batch, "test/predictions") def _log_image_text_caption(self, trainer, output, batch, key): xs, _ = batch gt_strs = output["gt_strs"] pred_strs = output["pred_strs"] mx = self.max_images_to_log xs, gt_strs, pred_strs = list(xs[:mx]), gt_strs[:mx], pred_strs[:mx] trainer.logger.log_image(key, xs, caption=pred_strs) def has_metrics(self, output): return all(key in output.keys() for key in self._required_keys) ================================================ FILE: lab08/text_recognizer/callbacks/model.py ================================================ import os from pathlib import Path import tempfile import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_only import torch from .util import check_and_warn, logging try: import torchviz has_torchviz = True except ImportError: has_torchviz = False class ModelSizeLogger(pl.Callback): """Logs information about model size (in parameters and on disk).""" def 
__init__(self, print_size=True): super().__init__() self.print_size = print_size @rank_zero_only def on_fit_start(self, trainer, module): self._run(trainer, module) def _run(self, trainer, module): metrics = {} metrics["mb_disk"] = self.get_model_disksize(module) metrics["nparams"] = count_params(module) if self.print_size: print(f"Model State Dict Disk Size: {round(metrics['mb_disk'], 2)} MB") metrics = {f"size/{key}": value for key, value in metrics.items()} trainer.logger.log_metrics(metrics, step=-1) @staticmethod def get_model_disksize(module): """Determine the model's size on disk by saving it to disk.""" with tempfile.NamedTemporaryFile() as f: torch.save(module.state_dict(), f) size_mb = os.path.getsize(f.name) / 1e6 return size_mb class GraphLogger(pl.Callback): """Logs a compute graph as an image.""" def __init__(self, output_key="logits"): super().__init__() self.graph_logged = False self.output_key = output_key if not has_torchviz: raise ImportError("GraphLogCallback requires torchviz." "") @rank_zero_only def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx): if not self.graph_logged: try: outputs = outputs[0][0]["extra"] self.log_graph(trainer, module, outputs[self.output_key]) except KeyError: logging.warning(f"Unable to log graph: outputs not found at key {self.output_key}") self.graph_logged = True @staticmethod def log_graph(trainer, module, outputs): if check_and_warn(trainer.logger, "log_image", "graph"): return params_dict = dict(list(module.named_parameters())) graph = torchviz.make_dot(outputs, params=params_dict) graph.format = "png" fname = Path(trainer.logger.experiment.dir) / "graph" graph.render(fname) fname = str(fname.with_suffix("." 
+ graph.format)) trainer.logger.log_image(key="graph", images=[fname]) def count_params(module): """Counts the number of parameters in a Torch Module.""" return sum(p.numel() for p in module.parameters()) ================================================ FILE: lab08/text_recognizer/callbacks/optim.py ================================================ import pytorch_lightning as pl KEY = "optimizer" class LearningRateMonitor(pl.callbacks.LearningRateMonitor): """Extends Lightning's LearningRateMonitor with a prefix. Logs the learning rate during training. See the docs for pl.callbacks.LearningRateMonitor for details. """ def _add_prefix(self, *args, **kwargs) -> str: return f"{KEY}/" + super()._add_prefix(*args, **kwargs) ================================================ FILE: lab08/text_recognizer/callbacks/util.py ================================================ import logging logging.basicConfig(level=logging.WARNING) def check_and_warn(logger, attribute, feature): if not hasattr(logger, attribute): warn_no_attribute(feature, attribute) return True def warn_no_attribute(blocked_feature, missing_attribute): logging.warning(f"Unable to log {blocked_feature}: logger does not have attribute {missing_attribute}.") ================================================ FILE: lab08/text_recognizer/data/__init__.py ================================================ """Module containing submodules for each dataset. Each dataset is defined as a class in that submodule. The datasets should have a .config method that returns any configuration information needed by the model. Most datasets define their constants in a submodule of the metadata module that is parallel to this one in the hierarchy. 
""" from .util import BaseDataset from .base_data_module import BaseDataModule from .mnist import MNIST from .emnist import EMNIST from .emnist_lines import EMNISTLines from .iam_paragraphs import IAMParagraphs from .iam_lines import IAMLines from .fake_images import FakeImageData from .iam_synthetic_paragraphs import IAMSyntheticParagraphs from .iam_original_and_synthetic_paragraphs import IAMOriginalAndSyntheticParagraphs ================================================ FILE: lab08/text_recognizer/data/base_data_module.py ================================================ """Base DataModule class.""" import argparse import os from pathlib import Path from typing import Collection, Dict, Optional, Tuple, Union import pytorch_lightning as pl import torch from torch.utils.data import ConcatDataset, DataLoader from text_recognizer import util from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.shared as metadata def load_and_print_info(data_module_class) -> None: """Load EMNISTLines and print info.""" parser = argparse.ArgumentParser() data_module_class.add_to_argparse(parser) args = parser.parse_args() dataset = data_module_class(args) dataset.prepare_data() dataset.setup() print(dataset) def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path: dl_dirname.mkdir(parents=True, exist_ok=True) filename = dl_dirname / metadata["filename"] if filename.exists(): return filename print(f"Downloading raw dataset from {metadata['url']} to {filename}...") util.download_url(metadata["url"], filename) print("Computing SHA-256...") sha256 = util.compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.") return filename BATCH_SIZE = 128 NUM_AVAIL_CPUS = len(os.sched_getaffinity(0)) NUM_AVAIL_GPUS = torch.cuda.device_count() # sensible multiprocessing defaults: at most one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS # but in distributed 
data parallel mode, we launch a training process on each GPU, so must divide out to keep total at one worker per CPU DEFAULT_NUM_WORKERS = NUM_AVAIL_CPUS // NUM_AVAIL_GPUS if NUM_AVAIL_GPUS else DEFAULT_NUM_WORKERS class BaseDataModule(pl.LightningDataModule): """Base for all of our LightningDataModules. Learn more about LDMs at https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html """ def __init__(self, args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.batch_size = self.args.get("batch_size", BATCH_SIZE) self.num_workers = self.args.get("num_workers", DEFAULT_NUM_WORKERS) self.on_gpu = isinstance(self.args.get("gpus", None), (str, int)) # Make sure to set the variables below in subclasses self.input_dims: Tuple[int, ...] self.output_dims: Tuple[int, ...] self.mapping: Collection self.data_train: Union[BaseDataset, ConcatDataset] self.data_val: Union[BaseDataset, ConcatDataset] self.data_test: Union[BaseDataset, ConcatDataset] @classmethod def data_dirname(cls): return metadata.DATA_DIRNAME @staticmethod def add_to_argparse(parser): parser.add_argument( "--batch_size", type=int, default=BATCH_SIZE, help=f"Number of examples to operate on per forward step. Default is {BATCH_SIZE}.", ) parser.add_argument( "--num_workers", type=int, default=DEFAULT_NUM_WORKERS, help=f"Number of additional processes to load data. Default is {DEFAULT_NUM_WORKERS}.", ) return parser def config(self): """Return important settings of the dataset, which will be passed to instantiate models.""" return {"input_dims": self.input_dims, "output_dims": self.output_dims, "mapping": self.mapping} def prepare_data(self, *args, **kwargs) -> None: """Take the first steps to prepare data for use. Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
""" def setup(self, stage: Optional[str] = None) -> None: """Perform final setup to prepare data for consumption by DataLoader. Here is where we typically split into train, validation, and test. This is done once per GPU in a DDP setting. Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test. """ def train_dataloader(self): return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def val_dataloader(self): return DataLoader( self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) def test_dataloader(self): return DataLoader( self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.on_gpu, ) ================================================ FILE: lab08/text_recognizer/data/emnist.py ================================================ """EMNIST dataset. Downloads from NIST website and saves as .npz file if not already present.""" import json import os from pathlib import Path import shutil from typing import Sequence import zipfile import h5py import numpy as np import toml from text_recognizer.data.base_data_module import _download_raw_dataset, BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset, split_dataset import text_recognizer.metadata.emnist as metadata from text_recognizer.stems.image import ImageStem from text_recognizer.util import temporary_working_directory NUM_SPECIAL_TOKENS = metadata.NUM_SPECIAL_TOKENS RAW_DATA_DIRNAME = metadata.RAW_DATA_DIRNAME METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME PROCESSED_DATA_FILENAME = metadata.PROCESSED_DATA_FILENAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME SAMPLE_TO_BALANCE = True # If true, take at most the mean number of instances per class. 
TRAIN_FRAC = 0.8 class EMNIST(BaseDataModule): """EMNIST dataset of handwritten characters and digits. "The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset." From https://www.nist.gov/itl/iad/image-group/emnist-dataset The data split we will use is EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ def __init__(self, args=None): super().__init__(args) self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.transform = ImageStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS def prepare_data(self, *args, **kwargs) -> None: if not os.path.exists(PROCESSED_DATA_FILENAME): _download_and_process_emnist() def setup(self, stage: str = None) -> None: if stage == "fit" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_trainval = f["x_train"][:] self.y_trainval = f["y_train"][:].squeeze().astype(int) data_trainval = BaseDataset(self.x_trainval, self.y_trainval, transform=self.transform) self.data_train, self.data_val = split_dataset(base_dataset=data_trainval, fraction=TRAIN_FRAC, seed=42) if stage == "test" or stage is None: with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: self.x_test = f["x_test"][:] self.y_test = f["y_test"][:].squeeze().astype(int) self.data_test = BaseDataset(self.x_test, self.y_test, transform=self.transform) def __repr__(self): basic = f"EMNIST Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\nDims: {self.input_dims}\n" if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Batch y stats: {(y.shape, y.dtype, 
y.min(), y.max())}\n" ) return basic + data def _download_and_process_emnist(): metadata = toml.load(METADATA_FILENAME) _download_raw_dataset(metadata, DL_DATA_DIRNAME) _process_raw_dataset(metadata["filename"], DL_DATA_DIRNAME) def _process_raw_dataset(filename: str, dirname: Path): print("Unzipping EMNIST...") with temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zf: zf.extract("matlab/emnist-byclass.mat") from scipy.io import loadmat # NOTE: If importing at the top of module, would need to list scipy as prod dependency. print("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = data["dataset"]["train"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_train = data["dataset"]["train"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS x_test = data["dataset"]["test"][0, 0]["images"][0, 0].reshape(-1, 28, 28).swapaxes(1, 2) y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS # NOTE that we add NUM_SPECIAL_TOKENS to targets, since these tokens are the first class indices if SAMPLE_TO_BALANCE: print("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) print("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") f.create_dataset("y_train", data=y_train, dtype="u1", compression="lzf") f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") print("Saving essential dataset parameters to text_recognizer/data...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(list(mapping.values())) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} 
with open(ESSENTIALS_FILENAME, "w") as f: json.dump(essentials, f) print("Cleaning up...") shutil.rmtree("matlab") def _sample_to_balance(x, y): """Because the dataset is not balanced, we take at most the mean number of instances per class.""" np.random.seed(42) num_to_sample = int(np.bincount(y.flatten()).mean()) all_sampled_inds = [] for label in np.unique(y.flatten()): inds = np.where(y == label)[0] sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) all_sampled_inds.append(sampled_inds) ind = np.concatenate(all_sampled_inds) x_sampled = x[ind] y_sampled = y[ind] return x_sampled, y_sampled def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: """Augment the mapping with extra symbols.""" # Extra characters from the IAM dataset iam_characters = [ " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] # Also add special tokens: # - CTC blank token at index 0 # - Start token at index 1 # - End token at index 2 # - Padding token at index 3 # NOTE: Don't forget to update NUM_SPECIAL_TOKENS if changing this! return ["<B>", "<S>", "<E>", "<P>", *characters, *iam_characters] if __name__ == "__main__": load_and_print_info(EMNIST) ================================================ FILE: lab08/text_recognizer/data/emnist_essentials.json ================================================ {"characters": ["<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} ================================================ FILE: lab08/text_recognizer/data/emnist_lines.py ================================================ import argparse from collections import defaultdict from typing import Dict, Sequence import h5py import numpy as np import torch from text_recognizer.data import EMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.util import BaseDataset import text_recognizer.metadata.emnist_lines as metadata from text_recognizer.stems.image import ImageStem PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME ESSENTIALS_FILENAME = metadata.ESSENTIALS_FILENAME DEFAULT_MAX_LENGTH = 32 DEFAULT_MIN_OVERLAP = 0 DEFAULT_MAX_OVERLAP = 0.33 NUM_TRAIN = 10000 NUM_VAL = 2000 NUM_TEST = 2000 class EMNISTLines(BaseDataModule): """EMNIST Lines dataset: synthetic handwriting lines dataset made from EMNIST characters.""" def __init__( self, args: argparse.Namespace = None, ): super().__init__(args) self.max_length = self.args.get("max_length", DEFAULT_MAX_LENGTH) self.min_overlap = self.args.get("min_overlap", DEFAULT_MIN_OVERLAP) self.max_overlap = self.args.get("max_overlap", DEFAULT_MAX_OVERLAP) self.num_train = self.args.get("num_train", NUM_TRAIN) self.num_val = self.args.get("num_val", NUM_VAL) self.num_test = self.args.get("num_test", NUM_TEST) self.with_start_end_tokens = self.args.get("with_start_end_tokens", False) self.mapping = metadata.MAPPING self.output_dims = (self.max_length, 1) max_width = metadata.CHAR_WIDTH * self.max_length self.input_dims =
(*metadata.DIMS[:2], max_width) self.emnist = EMNIST() self.transform = ImageStem() @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument( "--max_length", type=int, default=DEFAULT_MAX_LENGTH, help=f"Max line length in characters. Default is {DEFAULT_MAX_LENGTH}", ) parser.add_argument( "--min_overlap", type=float, default=DEFAULT_MIN_OVERLAP, help=f"Min overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MIN_OVERLAP}", ) parser.add_argument( "--max_overlap", type=float, default=DEFAULT_MAX_OVERLAP, help=f"Max overlap between characters in a line, between 0 and 1. Default is {DEFAULT_MAX_OVERLAP}", ) parser.add_argument("--with_start_end_tokens", action="store_true", default=False) return parser @property def data_filename(self): return ( PROCESSED_DATA_DIRNAME / f"ml_{self.max_length}_o{self.min_overlap:f}_{self.max_overlap:f}_ntr{self.num_train}_ntv{self.num_val}_nte{self.num_test}_{self.with_start_end_tokens}.h5" ) def prepare_data(self, *args, **kwargs) -> None: if self.data_filename.exists(): return np.random.seed(42) self._generate_data("train") self._generate_data("val") self._generate_data("test") def setup(self, stage: str = None) -> None: print("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: with h5py.File(self.data_filename, "r") as f: x_train = f["x_train"][:] y_train = f["y_train"][:].astype(int) x_val = f["x_val"][:] y_val = f["y_val"][:].astype(int) self.data_train = BaseDataset(x_train, y_train, transform=self.transform) self.data_val = BaseDataset(x_val, y_val, transform=self.transform) if stage == "test" or stage is None: with h5py.File(self.data_filename, "r") as f: x_test = f["x_test"][:] y_test = f["y_test"][:].astype(int) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "EMNIST Lines Dataset\n" f"Min overlap: {self.min_overlap}\n" f"Max overlap: 
{self.max_overlap}\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Batch x stats: {(x.shape, x.dtype, x.min().item(), x.mean().item(), x.std().item(), x.max().item())}\n" f"Batch y stats: {(y.shape, y.dtype, y.min().item(), y.max().item())}\n" ) return basic + data def _generate_data(self, split: str) -> None: print(f"EMNISTLinesDataset generating data for {split}...") from text_recognizer.data.sentence_generator import SentenceGenerator sentence_generator = SentenceGenerator(self.max_length - 2) # Subtract two because we will add start/end tokens emnist = self.emnist emnist.prepare_data() emnist.setup() if split == "train": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_train elif split == "val": samples_by_char = get_samples_by_char(emnist.x_trainval, emnist.y_trainval, emnist.mapping) num = self.num_val else: samples_by_char = get_samples_by_char(emnist.x_test, emnist.y_test, emnist.mapping) num = self.num_test PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(self.data_filename, "a") as f: x, y = create_dataset_of_images( num, samples_by_char, sentence_generator, self.min_overlap, self.max_overlap, self.input_dims ) y = convert_strings_to_labels( y, emnist.inverse_mapping, length=self.output_dims[0], with_start_end_tokens=self.with_start_end_tokens, ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def get_samples_by_char(samples, labels, mapping): samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): samples_by_char[mapping[label]].append(sample) return 
samples_by_char def select_letter_samples_for_string(string, samples_by_char, char_shape=(metadata.CHAR_HEIGHT, metadata.CHAR_WIDTH)): zero_image = torch.zeros(char_shape, dtype=torch.uint8) sample_image_by_char = {} for char in string: if char in sample_image_by_char: continue samples = samples_by_char[char] sample = samples[np.random.choice(len(samples))] if samples else zero_image sample_image_by_char[char] = sample.reshape(*char_shape) return [sample_image_by_char[char] for char in string] def construct_image_from_string( string: str, samples_by_char: dict, min_overlap: float, max_overlap: float, width: int ) -> torch.Tensor: overlap = np.random.uniform(min_overlap, max_overlap) sampled_images = select_letter_samples_for_string(string, samples_by_char) H, W = sampled_images[0].shape next_overlap_width = W - int(overlap * W) concatenated_image = torch.zeros((H, width), dtype=torch.uint8) x = 0 for image in sampled_images: concatenated_image[:, x : (x + W)] += image x += next_overlap_width return torch.minimum(torch.Tensor([255]), concatenated_image) def create_dataset_of_images(N, samples_by_char, sentence_generator, min_overlap, max_overlap, dims): images = torch.zeros((N, dims[1], dims[2])) labels = [] for n in range(N): label = sentence_generator.generate() images[n] = construct_image_from_string(label, samples_by_char, min_overlap, max_overlap, dims[-1]) labels.append(label) return images, labels def convert_strings_to_labels( strings: Sequence[str], mapping: Dict[str, int], length: int, with_start_end_tokens: bool ) -> np.ndarray: """ Convert sequence of N strings to a (N, length) ndarray, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = np.ones((len(strings), length), dtype=np.uint8) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) if with_start_end_tokens: tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels if __name__ == "__main__": load_and_print_info(EMNISTLines)
================================================ FILE: lab08/text_recognizer/data/fake_images.py ================================================ """A fake image dataset for testing.""" import argparse import torch import torchvision from text_recognizer.data.base_data_module import BaseDataModule _NUM_SAMPLES = 512 _IMAGE_LEN = 28 _NUM_CLASSES = 10 class FakeImageData(BaseDataModule): """Fake images dataset.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.num_samples = self.args.get("num_samples", _NUM_SAMPLES) self.input_dims = (1, self.args.get("image_height", _IMAGE_LEN), self.args.get("image_width", _IMAGE_LEN)) self.num_classes = self.args.get("num_classes", _NUM_CLASSES) self.output_dims = (self.num_classes, 1) self.mapping = list(range(0, self.num_classes)) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--num_samples", type=int, default=_NUM_SAMPLES) parser.add_argument("--num_classes", type=int, default=_NUM_CLASSES) parser.add_argument("--image_height", type=int, default=_IMAGE_LEN) parser.add_argument("--image_width", type=int, default=_IMAGE_LEN) return parser def setup(self, stage: str = None) -> None: fake_dataset = torchvision.datasets.FakeData( size=self.num_samples, image_size=self.input_dims, num_classes=self.output_dims[0], transform=torchvision.transforms.ToTensor(), ) val_size = int(self.num_samples * 0.25) self.data_train, self.data_val, self.data_test = torch.utils.data.random_split( # type: ignore dataset=fake_dataset, lengths=[self.num_samples - 2 * val_size, val_size, val_size] )
================================================ FILE: lab08/text_recognizer/data/iam.py
================================================ """Class for loading the IAM handwritten text dataset, which encompasses both paragraphs and lines, plus utilities.""" from pathlib import Path from typing import Any, cast, Dict, List, Optional import zipfile from boltons.cacheutils import cachedproperty from defusedxml import ElementTree from PIL import Image, ImageOps import toml from text_recognizer import util from text_recognizer.data.base_data_module import _download_raw_dataset, load_and_print_info import text_recognizer.metadata.iam as metadata from text_recognizer.metadata.iam_paragraphs import NEW_LINE_TOKEN METADATA_FILENAME = metadata.METADATA_FILENAME DL_DATA_DIRNAME = metadata.DL_DATA_DIRNAME EXTRACTED_DATASET_DIRNAME = metadata.EXTRACTED_DATASET_DIRNAME class IAM: """A dataset of images of handwritten text written on a form underneath a typewritten prompt. "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels." From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database Images are identified by their "form ID". These IDs are used to separate train, validation and test splits, as keys for dictionaries returning label and image crop region data, and more. The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (LWITLRT): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. 
""" def __init__(self): self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self): if self.xml_filenames: return filename = _download_raw_dataset(self.metadata, DL_DATA_DIRNAME) # type: ignore _extract_raw_dataset(filename, DL_DATA_DIRNAME) def load_image(self, id: str) -> Image.Image: """Load and return an image of an entire IAM form. The image is grayscale with white text on black background. This image will have the printed prompt text at the top, above the handwritten text. Images of individual words or lines and of whole paragraphs can be cropped out using the relevant crop region data. """ image = util.read_image_pil(self.form_filenames_by_id[id], grayscale=True) image = ImageOps.invert(image) return image def __repr__(self): """Print info about the dataset.""" info = ["IAM Dataset"] info.append(f"Total Images: {len(self.xml_filenames)}") info.append(f"Total Test Images: {len(self.test_ids)}") info.append(f"Total Paragraphs: {len(self.paragraph_string_by_id)}") num_lines = sum(len(line_regions) for line_regions in self.line_regions_by_id.values()) info.append(f"Total Lines: {num_lines}") return "\n\t".join(info) @cachedproperty def all_ids(self): """A list of all form IDs.""" return sorted([f.stem for f in self.xml_filenames]) @cachedproperty def ids_by_split(self): return {"train": self.train_ids, "val": self.validation_ids, "test": self.test_ids} @cachedproperty def split_by_id(self): """A dictionary mapping form IDs to their split according to IAM Lines LWITLRT.""" split_by_id = {id_: "train" for id_ in self.train_ids} split_by_id.update({id_: "val" for id_ in self.validation_ids}) split_by_id.update({id_: "test" for id_ in self.test_ids}) return split_by_id @cachedproperty def train_ids(self): """A list of form IDs which are in the IAM Lines LWITLRT training set.""" return list(set(self.all_ids) - (set(self.test_ids) | set(self.validation_ids))) @cachedproperty def test_ids(self): """A list of form IDs from the IAM Lines LWITLRT test set.""" 
return _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/testset.txt") @property def xml_filenames(self) -> List[Path]: """A list of the filenames of all .xml files, which contain label information.""" return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @cachedproperty def validation_ids(self): """A list of form IDs from IAM Lines LWITLRT validation sets 1 and 2.""" val_ids = _get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset1.txt") val_ids.extend(_get_ids_from_lwitlrt_split_file(EXTRACTED_DATASET_DIRNAME / "task/validationset2.txt")) return val_ids @property def form_filenames(self) -> List[Path]: """A list of the filenames of all .jpg files, which contain images of IAM forms.""" return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def xml_filenames_by_id(self): """A dictionary mapping form IDs to their XML label information files.""" return {filename.stem: filename for filename in self.xml_filenames} @property def form_filenames_by_id(self): """A dictionary mapping form IDs to their JPEG images.""" return {filename.stem: filename for filename in self.form_filenames} @cachedproperty def line_strings_by_id(self): """A dict mapping an IAM form id to its list of line texts.""" return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def line_regions_by_id(self): """A dict mapping an IAM form id to its list of line image crop regions.""" return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} @cachedproperty def paragraph_string_by_id(self): """A dict mapping an IAM form id to its paragraph text.""" return {id: NEW_LINE_TOKEN.join(line_strings) for id, line_strings in self.line_strings_by_id.items()} @cachedproperty def paragraph_region_by_id(self): """A dict mapping an IAM form id to its paragraph image crop region.""" return { id: { "x1": min(region["x1"] for region in line_regions), 
"y1": min(region["y1"] for region in line_regions), "x2": max(region["x2"] for region in line_regions), "y2": max(region["y2"] for region in line_regions), } for id, line_regions in self.line_regions_by_id.items() } def _extract_raw_dataset(filename: Path, dirname: Path) -> None: print("Extracting IAM data") with util.temporary_working_directory(dirname): with zipfile.ZipFile(filename, "r") as zip_file: zip_file.extractall() def _get_ids_from_lwitlrt_split_file(filename: str) -> List[str]: """Get the ids from Large Writer Independent Text Line Recognition Task (LWITLRT) data split file.""" with open(filename, "r") as f: line_ids_str = f.read() line_ids = line_ids_str.split("\n") page_ids = list({"-".join(line_id.split("-")[:2]) for line_id in line_ids if line_id}) return page_ids def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace &quot; with ".""" xml_line_elements = _get_line_elements_from_xml_file(filename) return [_get_text_from_xml_element(el) for el in xml_line_elements] def _get_text_from_xml_element(xml_element: Any) -> str: """Extract text from any XML element.""" return xml_element.attrib["text"].replace("&quot;", '"') def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get the line region dict for each line.""" xml_line_elements = _get_line_elements_from_xml_file(filename) line_regions = [ cast(Dict[str, int], _get_region_from_xml_element(xml_elem=el, xml_path="word/cmp")) for el in xml_line_elements ] assert all(region is not None for region in line_regions), "Line regions cannot be None" # next_line_region["y1"] - prev_line_region["y2"] can be negative due to overlapping characters line_gaps_y = [ max(next_line_region["y1"] - prev_line_region["y2"], 0) for next_line_region, prev_line_region in zip(line_regions[1:], line_regions[:-1]) ] post_line_gaps_y = line_gaps_y + [2 * metadata.LINE_REGION_PADDING] pre_line_gaps_y = [2 * metadata.LINE_REGION_PADDING] + 
line_gaps_y return [ { "x1": region["x1"] - metadata.LINE_REGION_PADDING, "x2": region["x2"] + metadata.LINE_REGION_PADDING, "y1": region["y1"] - min(metadata.LINE_REGION_PADDING, pre_line_gaps_y[i] // 2), "y2": region["y2"] + min(metadata.LINE_REGION_PADDING, post_line_gaps_y[i] // 2), } for i, region in enumerate(line_regions) ] def _get_line_elements_from_xml_file(filename: str) -> List[Any]: """Get all line xml elements from xml file.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec return xml_root_element.findall("handwritten-part/line") def _get_region_from_xml_element(xml_elem: Any, xml_path: str) -> Optional[Dict[str, int]]: """ Get region from input xml element. The region is downsampled because the stored images are also downsampled. Parameters ---------- xml_elem xml element can be a line or word element with x, y, width, and height attributes xml_path should be "word/cmp" if xml_elem is a line element, else "cmp" """ unit_elements = xml_elem.findall(xml_path) if not unit_elements: return None return { "x1": min(int(el.attrib["x"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y1": min(int(el.attrib["y"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "x2": max(int(el.attrib["x"]) + int(el.attrib["width"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, "y2": max(int(el.attrib["y"]) + int(el.attrib["height"]) for el in unit_elements) // metadata.DOWNSAMPLE_FACTOR, } if __name__ == "__main__": load_and_print_info(IAM) ================================================ FILE: lab08/text_recognizer/data/iam_lines.py ================================================ """A dataset of lines of handwritten text derived from the IAM dataset.""" import argparse import json from pathlib import Path from typing import Sequence import numpy as np from PIL import Image, ImageFile from text_recognizer import util from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam 
import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.line import IAMLineStem ImageFile.LOAD_TRUNCATED_IMAGES = True PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR class IAMLines(BaseDataModule): """Lines of text pulled from the IAM Handwriting database.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true") == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = IAMLineStem() self.trainval_transform = IAMLineStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if PROCESSED_DATA_DIRNAME.exists(): return print("Cropping IAM line regions...") iam = IAM() iam.prepare_data() crops_train, labels_train = generate_line_crops_and_labels(iam, "train") crops_val, labels_val = generate_line_crops_and_labels(iam, "val") crops_test, labels_test = generate_line_crops_and_labels(iam, "test") shapes = np.array([crop.size for crop in crops_train + crops_val + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] print("Saving images, labels, and statistics...") save_images_and_labels(crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_val, labels_val, "val", PROCESSED_DATA_DIRNAME) save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt", "w") as file: file.write(str(aspect_ratios.max())) def 
setup(self, stage: str = None) -> None: with open(PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt") as file: max_aspect_ratio = float(file.read()) image_width = int(metadata.IMAGE_HEIGHT * max_aspect_ratio) assert image_width <= metadata.IMAGE_WIDTH if stage == "fit" or stage is None: x_train, labels_train = load_processed_crops_and_labels("train", PROCESSED_DATA_DIRNAME) y_train = convert_strings_to_labels(labels_train, self.inverse_mapping, length=self.output_dims[0]) self.data_train = BaseDataset(x_train, y_train, transform=self.trainval_transform) x_val, labels_val = load_processed_crops_and_labels("val", PROCESSED_DATA_DIRNAME) y_val = convert_strings_to_labels(labels_val, self.inverse_mapping, length=self.output_dims[0]) self.data_val = BaseDataset(x_val, y_val, transform=self.trainval_transform) # quick check: do we have the right sequence lengths? assert self.output_dims[0] >= max([len(_) for _ in labels_train]) + 2 # Add 2 for start/end tokens. assert self.output_dims[0] >= max([len(_) for _ in labels_val]) + 2 # Add 2 for start/end tokens. 
if stage == "test" or stage is None: x_test, labels_test = load_processed_crops_and_labels("test", PROCESSED_DATA_DIRNAME) y_test = convert_strings_to_labels(labels_test, self.inverse_mapping, length=self.output_dims[0]) self.data_test = BaseDataset(x_test, y_test, transform=self.transform) assert self.output_dims[0] >= max([len(_) for _ in labels_test]) + 2 def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Lines Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def generate_line_crops_and_labels(iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR): """Create both cropped lines and associated labels from IAM, with resizing by default""" crops, labels = [], [] for iam_id in iam.ids_by_split[split]: labels += iam.line_strings_by_id[iam_id] image = iam.load_image(iam_id) for line in iam.line_regions_by_id[iam_id]: coords = [line[point] for point in ["x1", "y1", "x2", "y2"]] crop = image.crop(coords) crop = resize_image(crop, scale_factor=scale_factor) crops.append(crop) assert len(crops) == len(labels) return crops, labels def save_images_and_labels(crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path): (data_dirname / split).mkdir(parents=True, exist_ok=True) with open(data_dirname / split / "_labels.json", "w") as f: json.dump(labels, f) 
for ind, crop in enumerate(crops): crop.save(data_dirname / split / f"{ind}.png") def load_processed_crops_and_labels(split: str, data_dirname: Path): """Load line crops and labels for given split from processed directory.""" crops = load_processed_line_crops(split, data_dirname) labels = load_processed_line_labels(split, data_dirname) assert len(crops) == len(labels) return crops, labels def load_processed_line_crops(split: str, data_dirname: Path): """Load line crops for given split from processed directory.""" crop_filenames = sorted((data_dirname / split).glob("*.png"), key=lambda filename: int(Path(filename).stem)) crops = [util.read_image_pil(filename, grayscale=True) for filename in crop_filenames] return crops def load_processed_line_labels(split: str, data_dirname: Path): """Load line labels for given split from processed directory.""" with open(data_dirname / split / "_labels.json") as file: labels = json.load(file) return labels if __name__ == "__main__": load_and_print_info(IAMLines) ================================================ FILE: lab08/text_recognizer/data/iam_original_and_synthetic_paragraphs.py ================================================ """IAM Original and Synthetic Paragraphs Dataset class.""" import argparse from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs class IAMOriginalAndSyntheticParagraphs(BaseDataModule): """A concatenation of original and synthetic IAM paragraph datasets.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.iam_paragraphs = IAMParagraphs(args) self.iam_syn_paragraphs = IAMSyntheticParagraphs(args) self.input_dims = self.iam_paragraphs.input_dims self.output_dims = self.iam_paragraphs.output_dims self.mapping = self.iam_paragraphs.mapping self.inverse_mapping = 
{v: k for k, v in enumerate(self.mapping)} @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") IAMSyntheticParagraphs.add_to_argparse(parser) return parser def prepare_data(self, *args, **kwargs) -> None: self.iam_paragraphs.prepare_data() self.iam_syn_paragraphs.prepare_data() def setup(self, stage: str = None) -> None: self.iam_paragraphs.setup(stage) self.iam_syn_paragraphs.setup(stage) if stage == "fit" or stage is None: self.data_train = ConcatDataset([self.iam_paragraphs.data_train, self.iam_syn_paragraphs.data_train]) self.data_val = self.iam_paragraphs.data_val if stage == "test" or stage is None: self.data_test = self.iam_paragraphs.data_test def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Dims: {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data if __name__ == "__main__": load_and_print_info(IAMOriginalAndSyntheticParagraphs) ================================================ FILE: lab08/text_recognizer/data/iam_paragraphs.py ================================================ """IAM Paragraphs Dataset class.""" import argparse import json from pathlib import Path from typing import Callable, Dict, Optional, Sequence, Tuple import numpy as np 
from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.util import BaseDataset, convert_strings_to_labels, resize_image import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.paragraph import ParagraphStem IMAGE_SCALE_FACTOR = metadata.IMAGE_SCALE_FACTOR MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME class IAMParagraphs(BaseDataModule): """IAM Handwriting database paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.augment = self.args.get("augment_data", "true").lower() == "true" self.mapping = metadata.MAPPING self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)} self.input_dims = metadata.DIMS # We assert that this is correct in setup() self.output_dims = metadata.OUTPUT_DIMS # We assert that this is correct in setup() self.transform = ParagraphStem() self.trainval_transform = ParagraphStem(augment=self.augment) @staticmethod def add_to_argparse(parser): BaseDataModule.add_to_argparse(parser) parser.add_argument("--augment_data", type=str, default="true") return parser def prepare_data(self, *args, **kwargs) -> None: if (PROCESSED_DATA_DIRNAME / "_properties.json").exists(): return rank_zero_info( "IAMParagraphs.prepare_data: Cropping IAM paragraph regions and saving them along with labels..." 
) iam = IAM() iam.prepare_data() properties = {} for split in ["train", "val", "test"]: crops, labels = get_paragraph_crops_and_labels(iam=iam, split=split) save_crops_and_labels(crops=crops, labels=labels, split=split) properties.update( { id_: { "crop_shape": crops[id_].size[::-1], "label_length": len(label), "num_lines": _num_lines(label), } for id_, label in labels.items() } ) with open(PROCESSED_DATA_DIRNAME / "_properties.json", "w") as f: json.dump(properties, f, indent=4) def setup(self, stage: str = None) -> None: def _load_dataset(split: str, transform: Callable) -> BaseDataset: crops, labels = load_processed_crops_and_labels(split) Y = convert_strings_to_labels(strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]) return BaseDataset(crops, Y, transform=transform) rank_zero_info(f"IAMParagraphs.setup({stage}): Loading IAM paragraph regions and lines...") validate_input_and_output_dimensions(input_dims=self.input_dims, output_dims=self.output_dims) if stage == "fit" or stage is None: self.data_train = _load_dataset(split="train", transform=self.trainval_transform) self.data_val = _load_dataset(split="val", transform=self.transform) if stage == "test" or stage is None: self.data_test = _load_dataset(split="test", transform=self.transform) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None and self.data_val is None and self.data_test is None: return basic x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), 
xt.std(), xt.max())}\n" f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" ) return basic + data def validate_input_and_output_dimensions( input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] ) -> None: """Validate input and output dimensions against the properties of the dataset.""" properties = get_dataset_properties() max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR assert input_dims is not None and input_dims[1] >= max_image_shape[0] and input_dims[2] >= max_image_shape[1] # Add 2 because of start and end tokens assert output_dims is not None and output_dims[0] >= properties["label_length"]["max"] + 2 def get_paragraph_crops_and_labels( iam: IAM, split: str, scale_factor=IMAGE_SCALE_FACTOR ) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: """Create IAM paragraph crops and labels for a given split, with resizing.""" crops = {} labels = {} for iam_id in iam.ids_by_split[split]: image = iam.load_image(iam_id) para_region = iam.paragraph_region_by_id[iam_id] crops[iam_id] = image.crop([para_region[_] for _ in ["x1", "y1", "x2", "y2"]]) crops[iam_id] = resize_image(crops[iam_id], scale_factor=scale_factor) labels[iam_id] = iam.paragraph_string_by_id[iam_id] assert len(crops) == len(labels) return crops, labels def save_crops_and_labels(crops: Dict[str, Image.Image], labels: Dict[str, str], split: str): """Save crops, labels and shapes of crops of a split.""" (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) with open(_labels_filename(split), "w") as f: json.dump(labels, f, indent=4) for id_, crop in crops.items(): crop.save(_crop_filename(id_, split)) def load_processed_crops_and_labels(split: str) -> Tuple[Sequence[Image.Image], Sequence[str]]: """Load processed crops and labels for given split.""" with open(_labels_filename(split), "r") as f: labels = json.load(f) sorted_ids = sorted(labels.keys()) ordered_crops = [Image.open(_crop_filename(id_, split)).convert("L") for id_ in 
sorted_ids] ordered_labels = [labels[id_] for id_ in sorted_ids] assert len(ordered_crops) == len(ordered_labels) return ordered_crops, ordered_labels def get_dataset_properties() -> dict: """Return properties describing the overall dataset.""" with open(PROCESSED_DATA_DIRNAME / "_properties.json", "r") as f: properties = json.load(f) def _get_property_values(key: str) -> list: return [_[key] for _ in properties.values()] crop_shapes = np.array(_get_property_values("crop_shape")) aspect_ratios = crop_shapes[:, 1] / crop_shapes[:, 0] return { "label_length": { "min": min(_get_property_values("label_length")), "max": max(_get_property_values("label_length")), }, "num_lines": {"min": min(_get_property_values("num_lines")), "max": max(_get_property_values("num_lines"))}, "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0)}, "aspect_ratio": {"min": aspect_ratios.min(), "max": aspect_ratios.max()}, } def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" def _crop_filename(id_: str, split: str) -> Path: """Return filename of processed crop.""" return PROCESSED_DATA_DIRNAME / split / f"{id_}.png" def _num_lines(label: str) -> int: """Return number of lines of text in label.""" return label.count(NEW_LINE_TOKEN) + 1 if __name__ == "__main__": load_and_print_info(IAMParagraphs) ================================================ FILE: lab08/text_recognizer/data/iam_synthetic_paragraphs.py ================================================ """IAM Synthetic Paragraphs Dataset class.""" import argparse import random from typing import Any, Callable, List, Sequence, Tuple import numpy as np from PIL import Image from pytorch_lightning.utilities.rank_zero import rank_zero_info import torch from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( generate_line_crops_and_labels, 
load_processed_line_crops, load_processed_line_labels, save_images_and_labels, ) from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.util import convert_strings_to_labels import text_recognizer.metadata.iam_synthetic_paragraphs as metadata NEW_LINE_TOKEN = metadata.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = metadata.PROCESSED_DATA_DIRNAME DATASET_LEN = metadata.DATASET_LEN class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database synthetic paragraphs.""" def __init__(self, args: argparse.Namespace = None): super().__init__(args) self.line_crops = None self.line_labels = None self.dataset_len = self.args.get("dataset_len", DATASET_LEN) def prepare_data(self, *args, **kwargs) -> None: """ Prepare IAM lines such that they can be used to generate synthetic paragraphs dataset in setup(). This method is IAMLines.prepare_data + resizing of line crops. """ if PROCESSED_DATA_DIRNAME.exists(): return rank_zero_info( "IAMSyntheticParagraphs.prepare_data: preparing IAM lines for synthetic IAM paragraph creation..." 
) iam = IAM() iam.prepare_data() for split in ["train"]: # synthetic dataset is only used in training phase rank_zero_info(f"Cropping IAM line regions and loading labels for {split} data split...") crops, labels = generate_line_crops_and_labels(iam, split) save_images_and_labels(crops, labels, split, PROCESSED_DATA_DIRNAME) def setup(self, stage: str = None) -> None: rank_zero_info(f"IAMSyntheticParagraphs.setup({stage}): Loading train IAM paragraph regions and lines...") if stage == "fit" or stage is None: self._load_processed_crops_and_labels() self.data_train = IAMSyntheticParagraphsDataset( line_crops=self.line_crops, line_labels=self.line_labels, dataset_len=self.dataset_len, inverse_mapping=self.inverse_mapping, input_dims=self.input_dims, output_dims=self.output_dims, transform=self.trainval_transform, ) def _load_processed_crops_and_labels(self): if self.line_crops is None: self.line_crops = load_processed_line_crops("train", PROCESSED_DATA_DIRNAME) if self.line_labels is None: self.line_labels = load_processed_line_labels("train", PROCESSED_DATA_DIRNAME) def __repr__(self) -> str: """Print info about the dataset.""" basic = ( "IAM Synthetic Paragraphs Dataset\n" f"Num classes: {len(self.mapping)}\n" f"Input dims : {self.input_dims}\n" f"Output dims: {self.output_dims}\n" ) if self.data_train is None: return basic x, y = next(iter(self.train_dataloader())) data = ( f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n" f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" ) return basic + data def add_to_argparse(parser): parser.add_argument("--dataset_len", type=int, default=DATASET_LEN) return parser class IAMSyntheticParagraphsDataset(torch.utils.data.Dataset): """Dataset of synthetic paragraphs built out of individual IAM lines.""" def __init__( self, line_crops: List[Image.Image], line_labels: List[str], dataset_len: int, inverse_mapping: dict, input_dims: 
Tuple[int, ...], output_dims: Tuple[int, ...], transform: Callable = None, ) -> None: super().__init__() self.line_crops = line_crops self.line_labels = line_labels assert len(self.line_crops) == len(self.line_labels) self.ids = list(range(len(self.line_labels))) self.dataset_len = dataset_len self.inverse_mapping = inverse_mapping self.input_dims = input_dims self.output_dims = output_dims self.transform = transform self.min_num_lines, self.max_num_lines = 1, 15 self.seed_set = False def __len__(self) -> int: """Return length of the dataset.""" return self.dataset_len def _set_seed(self, seed): if not self.seed_set: print(f"Setting seed to {seed} for worker {torch.utils.data.get_worker_info()}") random.seed(seed) self.seed_set = True def __getitem__(self, index: int) -> Tuple[Any, Any]: """Return a random paragraph, using the first index as a seed.""" # Since shuffle is True for train dataloaders, the first index will be different on different GPUs self._set_seed(index) num_lines = random.randint(self.min_num_lines, self.max_num_lines) indices = random.sample(self.ids, k=num_lines) while True: datum = join_line_crops_to_form_paragraph([self.line_crops[i] for i in indices]) labels = NEW_LINE_TOKEN.join([self.line_labels[i] for i in indices]) if ( (len(labels) <= self.output_dims[0] - 2) and (datum.height <= self.input_dims[1]) and (datum.width <= self.input_dims[2]) ): break indices = indices[:-1] if self.transform is not None: datum = self.transform(datum) length = self.output_dims[0] target = convert_strings_to_labels(strings=[labels], mapping=self.inverse_mapping, length=length)[0] return datum, target def join_line_crops_to_form_paragraph(line_crops: Sequence[Image.Image]) -> Image.Image: """Vertically stack line crops and return a single image forming the paragraph.""" crop_shapes = np.array([_.size[::-1] for _ in line_crops]) para_height = crop_shapes[:, 0].sum() para_width = crop_shapes[:, 1].max() para_image = Image.new(mode="L", size=(para_width,
para_height), color=0) current_height = 0 for line_crop in line_crops: para_image.paste(line_crop, box=(0, current_height)) current_height += line_crop.height return para_image if __name__ == "__main__": load_and_print_info(IAMSyntheticParagraphs) ================================================ FILE: lab08/text_recognizer/data/mnist.py ================================================ """MNIST DataModule.""" import argparse from torch.utils.data import random_split from torchvision.datasets import MNIST as TorchMNIST from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info import text_recognizer.metadata.mnist as metadata from text_recognizer.stems.image import MNISTStem class MNIST(BaseDataModule): """MNIST DataModule.""" def __init__(self, args: argparse.Namespace) -> None: super().__init__(args) self.data_dir = metadata.DOWNLOADED_DATA_DIRNAME self.transform = MNISTStem() self.input_dims = metadata.DIMS self.output_dims = metadata.OUTPUT_DIMS self.mapping = metadata.MAPPING def prepare_data(self, *args, **kwargs) -> None: """Download train and test MNIST data from PyTorch canonical source.""" TorchMNIST(self.data_dir, train=True, download=True) TorchMNIST(self.data_dir, train=False, download=True) def setup(self, stage=None) -> None: """Split into train, val, test, and set dims.""" mnist_full = TorchMNIST(self.data_dir, train=True, transform=self.transform) self.data_train, self.data_val = random_split(mnist_full, [metadata.TRAIN_SIZE, metadata.VAL_SIZE]) # type: ignore self.data_test = TorchMNIST(self.data_dir, train=False, transform=self.transform) if __name__ == "__main__": load_and_print_info(MNIST) ================================================ FILE: lab08/text_recognizer/data/sentence_generator.py ================================================ """SentenceGenerator class and supporting functions.""" import itertools import re import string from typing import List, Optional import nltk import numpy as np from 
text_recognizer.data.base_data_module import BaseDataModule NLTK_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "nltk" class SentenceGenerator: """Generate text sentences using the Brown corpus.""" def __init__(self, max_length: Optional[int] = None): self.text = brown_text() self.word_start_inds = [0] + [_.start(0) + 1 for _ in re.finditer(" ", self.text)] self.max_length = max_length def generate(self, max_length: Optional[int] = None) -> str: """Sample a string from text of the Brown corpus of length at least one word and at most max_length.""" if max_length is None: max_length = self.max_length if max_length is None: raise ValueError("Must provide max_length to this method or when making this object.") sampled_text, num_tries = None, 0 while (not sampled_text) and (num_tries <= 10): # try several times to generate sample text first_ind = np.random.randint(0, len(self.word_start_inds) - 1) start_ind = self.word_start_inds[first_ind] end_ind_candidates = self._get_end_ind_candidates(first_ind, start_ind, max_length) if len(end_ind_candidates) == 0: # sampling failed, try again num_tries += 1 continue else: end_ind = np.random.choice(end_ind_candidates) sampled_text = self.text[start_ind:end_ind].strip() if sampled_text is not None: return sampled_text else: raise RuntimeError("Was not able to generate a valid string") def _get_end_ind_candidates(self, first_ind: int, start_ind: int, max_length: int) -> List[int]: end_ind_candidates = [] for ind in range(first_ind + 1, len(self.word_start_inds)): if self.word_start_inds[ind] - start_ind > max_length: break end_ind_candidates.append(self.word_start_inds[ind]) return end_ind_candidates def brown_text(): """Return a single string with the Brown corpus with all punctuation stripped.""" sents = load_nltk_brown_corpus() text = " ".join(itertools.chain.from_iterable(sents)) text = text.translate({ord(c): None for c in string.punctuation}) text = re.sub(" +", " ", text) return text def 
load_nltk_brown_corpus(): """Load the Brown corpus using the NLTK library.""" nltk.data.path.append(NLTK_DATA_DIRNAME) try: nltk.corpus.brown.sents() except LookupError: NLTK_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) nltk.download("brown", download_dir=NLTK_DATA_DIRNAME) return nltk.corpus.brown.sents() ================================================ FILE: lab08/text_recognizer/data/util.py ================================================ """Base Dataset class.""" from typing import Any, Callable, Dict, Sequence, Tuple, Union from PIL import Image import torch SequenceOrTensor = Union[Sequence, torch.Tensor] class BaseDataset(torch.utils.data.Dataset): """Base Dataset class that simply processes data and targets through optional transforms. Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset Parameters ---------- data commonly these are torch tensors, numpy arrays, or PIL Images targets commonly these are torch tensors or numpy arrays transform function that takes a datum and returns the same target_transform function that takes a target and returns the same """ def __init__( self, data: SequenceOrTensor, targets: SequenceOrTensor, transform: Callable = None, target_transform: Callable = None, ) -> None: if len(data) != len(targets): raise ValueError("Data and targets must be of equal length") super().__init__() self.data = data self.targets = targets self.transform = transform self.target_transform = target_transform def __len__(self) -> int: """Return length of the dataset.""" return len(self.data) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Return a datum and its target, after processing by transforms. 
Parameters ---------- index Returns ------- (datum, target) """ datum, target = self.data[index], self.targets[index] if self.transform is not None: datum = self.transform(datum) if self.target_transform is not None: target = self.target_transform(target) return datum, target def convert_strings_to_labels(strings: Sequence[str], mapping: Dict[str, int], length: int) -> torch.Tensor: """ Convert sequence of N strings to a (N, length) tensor, with each string wrapped with <S> and <E> tokens, and padded with the <P> token. """ labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<P>"] for i, string in enumerate(strings): tokens = list(string) tokens = ["<S>", *tokens, "<E>"] for ii, token in enumerate(tokens): labels[i, ii] = mapping[token] return labels def split_dataset(base_dataset: BaseDataset, fraction: float, seed: int) -> Tuple[BaseDataset, BaseDataset]: """ Split input base_dataset into 2 base datasets, the first of size fraction * size of the base_dataset and the other of size (1 - fraction) * size of the base_dataset. """ split_a_size = int(fraction * len(base_dataset)) split_b_size = len(base_dataset) - split_a_size return torch.utils.data.random_split( # type: ignore base_dataset, [split_a_size, split_b_size], generator=torch.Generator().manual_seed(seed) ) def resize_image(image: Image.Image, scale_factor: int) -> Image.Image: """Resize image by scale factor.""" if scale_factor == 1: return image return image.resize((image.width // scale_factor, image.height // scale_factor), resample=Image.BILINEAR) ================================================ FILE: lab08/text_recognizer/lit_models/__init__.py ================================================ from .base import BaseLitModel from .transformer import TransformerLitModel ================================================ FILE: lab08/text_recognizer/lit_models/base.py ================================================ """Basic LightningModules on which other modules can be built.""" import argparse import pytorch_lightning as pl import torch from torchmetrics import Accuracy from .metrics import CharacterErrorRate OPTIMIZER = "Adam" LR = 1e-3 LOSS = "cross_entropy" ONE_CYCLE_TOTAL_STEPS = 100 class BaseLitModel(pl.LightningModule): """ Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
""" def __init__(self, model, args: argparse.Namespace = None): super().__init__() self.model = model self.args = vars(args) if args is not None else {} self.data_config = self.model.data_config self.mapping = self.data_config["mapping"] self.input_dims = self.data_config["input_dims"] optimizer = self.args.get("optimizer", OPTIMIZER) self.optimizer_class = getattr(torch.optim, optimizer) self.lr = self.args.get("lr", LR) loss = self.args.get("loss", LOSS) if loss not in ("transformer",): self.loss_fn = getattr(torch.nn.functional, loss) self.one_cycle_max_lr = self.args.get("one_cycle_max_lr", None) self.one_cycle_total_steps = self.args.get("one_cycle_total_steps", ONE_CYCLE_TOTAL_STEPS) self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() @staticmethod def add_to_argparse(parser): parser.add_argument("--optimizer", type=str, default=OPTIMIZER, help="optimizer class from torch.optim") parser.add_argument("--lr", type=float, default=LR) parser.add_argument("--one_cycle_max_lr", type=float, default=None) parser.add_argument("--one_cycle_total_steps", type=int, default=ONE_CYCLE_TOTAL_STEPS) parser.add_argument("--loss", type=str, default=LOSS, help="loss function from torch.nn.functional") return parser def configure_optimizers(self): optimizer = self.optimizer_class(self.parameters(), lr=self.lr) if self.one_cycle_max_lr is None: return optimizer scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps ) return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "validation/loss"} def forward(self, x): return self.model(x) def predict(self, x): logits = self.model(x) return torch.argmax(logits, dim=1) def training_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.train_acc(logits, y) self.log("train/loss", loss) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) outputs = {"loss": loss} 
self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def _run_on_batch(self, batch, with_preds=False): x, y = batch logits = self(x) loss = self.loss_fn(logits, y) return x, y, logits, loss def validation_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.val_acc(logits, y) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) self.log("validation/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) outputs = {"loss": loss} self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y, logits, loss = self._run_on_batch(batch) self.test_acc(logits, y) self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) def add_on_first_batch(self, metrics, outputs, batch_idx): if batch_idx == 0: outputs.update(metrics) def add_on_logged_batches(self, metrics, outputs): if self.is_logged_batch(): outputs.update(metrics) def is_logged_batch(self): if self.trainer is None: return False else: return self.trainer._logger_connector.should_update_logs class BaseImageToTextLitModel(BaseLitModel): # pylint: disable=too-many-ancestors """Base class for ImageToText models in PyTorch Lightning.""" def __init__(self, model, args: argparse.Namespace = None): super().__init__(model, args) self.model = model self.args = vars(args) if args is not None else {} self.inverse_mapping = {val: ind for ind, val in enumerate(self.mapping)} self.start_index = self.inverse_mapping["<S>"] self.end_index = self.inverse_mapping["<E>"] self.padding_index = self.inverse_mapping["<P>"] self.ignore_tokens = [self.start_index, self.end_index, self.padding_index] self.val_cer = CharacterErrorRate(self.ignore_tokens) self.test_cer = CharacterErrorRate(self.ignore_tokens) ================================================ FILE: lab08/text_recognizer/lit_models/metrics.py ================================================ """Special-purpose metrics for tracking our model performance.""" from typing import Sequence import torch import torchmetrics class CharacterErrorRate(torchmetrics.CharErrorRate): """Character error rate metric, allowing for tokens to be ignored.""" def __init__(self, ignore_tokens: Sequence[int], *args): super().__init__(*args) self.ignore_tokens = set(ignore_tokens) def update(self, preds: torch.Tensor, targets: torch.Tensor): # type: ignore preds_l = [[t for t in pred if t not in self.ignore_tokens] for pred in preds.tolist()] targets_l = [[t for t in target if t not in self.ignore_tokens] for target in targets.tolist()] super().update(preds_l, targets_l) def test_character_error_rate(): metric = CharacterErrorRate([0, 1]) X = torch.tensor( [ [0, 2, 2, 3, 3, 1], # error will be 0 [0, 2, 1, 1, 1, 1], # error will be .75 [0, 2, 2, 4, 4, 1], # error will be .5 ] ) Y = torch.tensor( [ [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], ] ) metric(X, Y) assert metric.compute() == sum([0, 0.75, 0.5]) / 3 if __name__ == "__main__": test_character_error_rate() ================================================ FILE: lab08/text_recognizer/lit_models/transformer.py ================================================ """An encoder-decoder Transformer model""" from typing import List, Sequence import torch from .base import BaseImageToTextLitModel from .util import replace_after class TransformerLitModel(BaseImageToTextLitModel): """ Generic image to text PyTorch-Lightning module that must be initialized with a PyTorch module.
The module must implement an encode and decode method, and the forward method should be the forward pass during production inference. """ def __init__(self, model, args=None): super().__init__(model, args) self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.padding_index) def forward(self, x): return self.model(x) def teacher_forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Uses provided sequence y as guide for non-autoregressive encoding-decoding of x. Parameters ---------- x Batch of images to be encoded. See self.model.encode for shape information. y Batch of ground truth output sequences. Returns ------- torch.Tensor (B, C, Sy) logits """ x = self.model.encode(x) output = self.model.decode(x, y) # (Sy, B, C) return output.permute(1, 2, 0) # (B, C, Sy) def training_step(self, batch, batch_idx): x, y = batch logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("train/loss", loss) outputs = {"loss": loss} if self.is_logged_batch(): preds = self.get_preds(logits) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) outputs.update({"pred_strs": pred_strs, "gt_strs": gt_strs}) return outputs def validation_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) loss = self.loss_fn(logits, y[:, 1:]) self.log("validation/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.val_cer(preds, y) self.log("validation/cer", self.val_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def test_step(self, batch, batch_idx): x, y = batch # compute loss as in training, for comparison logits = self.teacher_forward(x, y[:, :-1]) 
loss = self.loss_fn(logits, y[:, 1:]) self.log("test/loss", loss, prog_bar=True, sync_dist=True) outputs = {"loss": loss} # compute predictions as in production, for comparison preds = self(x) self.test_cer(preds, y) self.log("test/cer", self.test_cer, prog_bar=True, sync_dist=True) pred_strs, gt_strs = self.batchmap(preds), self.batchmap(y) self.add_on_first_batch({"pred_strs": pred_strs, "gt_strs": gt_strs}, outputs, batch_idx) self.add_on_first_batch({"logits": logits.detach()}, outputs, batch_idx) return outputs def map(self, ks: Sequence[int], ignore: bool = True) -> str: """Maps an iterable of integers to a string using the lit model's mapping.""" if ignore: return "".join([self.mapping[k] for k in ks if k not in self.ignore_tokens]) else: return "".join([self.mapping[k] for k in ks]) def batchmap(self, ks: Sequence[Sequence[int]], ignore=True) -> List[str]: """Maps a list of lists of integers to a list of strings using the lit model's mapping.""" return [self.map(k, ignore) for k in ks] def get_preds(self, logitlikes: torch.Tensor, replace_after_end: bool = True) -> torch.Tensor: """Converts logit-like Tensors into prediction indices, optionally overwritten after end token index. Parameters ---------- logitlikes (B, C, Sy) Tensor with classes as second dimension. The largest value is the one whose index we will return. Logits, logprobs, and probs are all acceptable. replace_after_end Whether to replace values after the first appearance of the end token with the padding token. Returns ------- torch.Tensor (B, Sy) Tensor of integers in [0, C-1] representing predictions.
""" raw = torch.argmax(logitlikes, dim=1) # (B, C, Sy) -> (B, Sy) if replace_after_end: return replace_after(raw, self.end_index, self.padding_index) # (B, Sy) else: return raw # (B, Sy) ================================================ FILE: lab08/text_recognizer/lit_models/util.py ================================================ from typing import Union import torch def first_appearance(x: torch.Tensor, element: Union[int, float], dim: int = 1) -> torch.Tensor: """Return indices of first appearance of element in x, collapsing along dim. Based on https://discuss.pytorch.org/t/first-nonzero-index/24769/9 Parameters ---------- x One or two-dimensional Tensor to search for element. element Item to search for inside x. dim Dimension of Tensor to collapse over. Returns ------- torch.Tensor Indices where element occurs in x. If element is not found, return length of x along dim. One dimension smaller than x. Raises ------ ValueError if x is not a 1 or 2 dimensional Tensor Examples -------- >>> first_appearance(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3) tensor([2, 1, 3, 0]) >>> first_appearance(torch.tensor([1, 2, 3]), 1, dim=0) tensor(0) """ if x.dim() > 2 or x.dim() == 0: raise ValueError(f"only 1 or 2 dimensional Tensors allowed, got Tensor with dim {x.dim()}") matches = x == element first_appearance_mask = (matches.cumsum(dim) == 1) & matches does_match, match_index = first_appearance_mask.max(dim) first_inds = torch.where(does_match, match_index, x.shape[dim]) return first_inds def replace_after(x: torch.Tensor, element: Union[int, float], replace: Union[int, float]) -> torch.Tensor: """Replace all values in each row of 2d Tensor x after the first appearance of element with replace. Parameters ---------- x Two-dimensional Tensor (shape denoted (B, S)) to replace values in. element Item to search for inside x. replace Item that replaces entries that appear after element. 
Returns ------- outs New Tensor of same shape as x with values after element replaced. Examples -------- >>> replace_after(torch.tensor([[1, 2, 3], [2, 3, 3], [1, 1, 1], [3, 1, 1]]), 3, 4) tensor([[1, 2, 3], [2, 3, 4], [1, 1, 1], [3, 4, 4]]) """ first_appearances = first_appearance(x, element, dim=1) # (B,) indices = torch.arange(0, x.shape[-1]).type_as(x) # (S,) outs = torch.where( indices[None, :] <= first_appearances[:, None], # if index is at or before first appearance x, # return the value from x replace, # otherwise, return the replacement value ) return outs # (B, S) ================================================ FILE: lab08/text_recognizer/metadata/emnist.py ================================================ from pathlib import Path import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "emnist" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "emnist" PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist" PROCESSED_DATA_FILENAME = PROCESSED_DATA_DIRNAME / "byclass.h5" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_essentials.json" NUM_SPECIAL_TOKENS = 4 INPUT_SHAPE = (28, 28) DIMS = (1, *INPUT_SHAPE) # Extra dimension added by ToTensor() OUTPUT_DIMS = (1,) MAPPING = [ "<B>", "<S>", "<E>", "<P>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", '"', "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?", ] ================================================ FILE: lab08/text_recognizer/metadata/emnist_lines.py ================================================ from pathlib import Path import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "emnist_lines" ESSENTIALS_FILENAME = Path(__file__).parents[1].resolve() / "data" / "emnist_lines_essentials.json" CHAR_HEIGHT, CHAR_WIDTH = emnist.DIMS[1:3] DIMS = (emnist.DIMS[0], CHAR_HEIGHT, None) # width variable, depends on maximum sequence length MAPPING = emnist.MAPPING ================================================ FILE: lab08/text_recognizer/metadata/iam.py ================================================ import text_recognizer.metadata.shared as shared RAW_DATA_DIRNAME = shared.DATA_DIRNAME / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = shared.DATA_DIRNAME / "downloaded" / "iam" EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" DOWNSAMPLE_FACTOR = 2 # if images were downsampled, the regions must also be LINE_REGION_PADDING = 8 # add this many pixels around the exact coordinates ================================================ FILE: lab08/text_recognizer/metadata/iam_lines.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_lines" IMAGE_SCALE_FACTOR = 2 CHAR_WIDTH = emnist.INPUT_SHAPE[0] // IMAGE_SCALE_FACTOR # rough estimate
IMAGE_HEIGHT = 112 // IMAGE_SCALE_FACTOR IMAGE_WIDTH = 3072 // IMAGE_SCALE_FACTOR # rounding up IAMLines empirical maximum width DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (89, 1) MAPPING = emnist.MAPPING ================================================ FILE: lab08/text_recognizer/metadata/iam_paragraphs.py ================================================ import text_recognizer.metadata.emnist as emnist import text_recognizer.metadata.shared as shared PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_paragraphs" NEW_LINE_TOKEN = "\n" MAPPING = [*emnist.MAPPING, NEW_LINE_TOKEN] # must match IMAGE_SCALE_FACTOR for IAMLines to be compatible with synthetic paragraphs IMAGE_SCALE_FACTOR = 2 IMAGE_HEIGHT, IMAGE_WIDTH = 576, 640 IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH) MAX_LABEL_LENGTH = 682 DIMS = (1, IMAGE_HEIGHT, IMAGE_WIDTH) OUTPUT_DIMS = (MAX_LABEL_LENGTH, 1) ================================================ FILE: lab08/text_recognizer/metadata/iam_synthetic_paragraphs.py ================================================ import text_recognizer.metadata.iam_paragraphs as iam_paragraphs import text_recognizer.metadata.shared as shared NEW_LINE_TOKEN = iam_paragraphs.NEW_LINE_TOKEN PROCESSED_DATA_DIRNAME = shared.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs" EXPECTED_BATCH_SIZE = 64 EXPECTED_GPUS = 8 EXPECTED_STEPS = 40 # set the dataset's length based on parameters during typical training DATASET_LEN = EXPECTED_BATCH_SIZE * EXPECTED_GPUS * EXPECTED_STEPS ================================================ FILE: lab08/text_recognizer/metadata/mnist.py ================================================ """Metadata for the MNIST dataset.""" import text_recognizer.metadata.shared as shared DOWNLOADED_DATA_DIRNAME = shared.DOWNLOADED_DATA_DIRNAME DIMS = (1, 28, 28) OUTPUT_DIMS = (1,) MAPPING = list(range(10)) TRAIN_SIZE = 55000 VAL_SIZE = 5000 ================================================ FILE: lab08/text_recognizer/metadata/shared.py 
================================================ from pathlib import Path DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" DOWNLOADED_DATA_DIRNAME = DATA_DIRNAME / "downloaded" ================================================ FILE: lab08/text_recognizer/models/__init__.py ================================================ """Models for character and text recognition in images.""" from .mlp import MLP from .cnn import CNN from .line_cnn_simple import LineCNNSimple from .resnet_transformer import ResnetTransformer from .line_cnn_transformer import LineCNNTransformer ================================================ FILE: lab08/text_recognizer/models/cnn.py ================================================ """Basic convolutional model building blocks.""" import argparse from typing import Any, Dict import torch from torch import nn import torch.nn.functional as F CONV_DIM = 64 FC_DIM = 128 FC_DROPOUT = 0.25 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__(self, input_channels: int, output_channels: int) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class CNN(nn.Module): """Simple CNN for recognizing characters in a square image.""" def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_channels, input_height, input_width = self.data_config["input_dims"] assert ( input_height == input_width ), f"input height and width should be equal, but was {input_height}, {input_width}" self.input_height, self.input_width = input_height, input_width num_classes = len(self.data_config["mapping"]) conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.conv1 = ConvBlock(input_channels, conv_dim) self.conv2 = ConvBlock(conv_dim, conv_dim) self.dropout = nn.Dropout(fc_dropout) self.max_pool = nn.MaxPool2d(2) # Because our 3x3 convs have padding size 1, they leave the input size unchanged. # The 2x2 max-pool divides the input size by 2. conv_output_height, conv_output_width = input_height // 2, input_width // 2 self.fc_input_dim = int(conv_output_height * conv_output_width * conv_dim) self.fc1 = nn.Linear(self.fc_input_dim, fc_dim) self.fc2 = nn.Linear(fc_dim, num_classes) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the CNN to x. Parameters ---------- x (B, Ch, H, W) tensor, where H and W must equal input height and width from data_config. 
Returns ------- torch.Tensor (B, Cl) tensor """ _B, _Ch, H, W = x.shape assert H == self.input_height and W == self.input_width, f"bad inputs to CNN with shape {x.shape}" x = self.conv1(x) # _B, CONV_DIM, H, W x = self.conv2(x) # _B, CONV_DIM, H, W x = self.max_pool(x) # _B, CONV_DIM, H // 2, W // 2 x = self.dropout(x) x = torch.flatten(x, 1) # _B, CONV_DIM * H // 2 * W // 2 x = self.fc1(x) # _B, FC_DIM x = F.relu(x) x = self.fc2(x) # _B, Cl return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab08/text_recognizer/models/line_cnn.py ================================================ """Basic building blocks for convolutional models over lines of text.""" import argparse import math from typing import Any, Dict, Tuple, Union import torch from torch import nn import torch.nn.functional as F # Common type hints Param2D = Union[int, Tuple[int, int]] CONV_DIM = 32 FC_DIM = 512 FC_DROPOUT = 0.2 WINDOW_WIDTH = 16 WINDOW_STRIDE = 8 class ConvBlock(nn.Module): """ Simple 3x3 conv with padding size 1 (to leave the input size unchanged), followed by a ReLU. """ def __init__( self, input_channels: int, output_channels: int, kernel_size: Param2D = 3, stride: Param2D = 1, padding: Param2D = 1, ) -> None: super().__init__() self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.relu = nn.ReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the ConvBlock to x. 
Parameters ---------- x (B, C, H, W) tensor Returns ------- torch.Tensor (B, C, H, W) tensor """ c = self.conv(x) r = self.relu(c) return r class LineCNN(nn.Module): """ Model that uses a simple CNN to process an image of a line of characters with a window, outputs a sequence of logits """ def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.args = vars(args) if args is not None else {} self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] _C, H, _W = data_config["input_dims"] conv_dim = self.args.get("conv_dim", CONV_DIM) fc_dim = self.args.get("fc_dim", FC_DIM) fc_dropout = self.args.get("fc_dropout", FC_DROPOUT) self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) # Input is (1, H, W) self.convs = nn.Sequential( ConvBlock(1, conv_dim), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim, stride=2), ConvBlock(conv_dim, conv_dim), ConvBlock(conv_dim, conv_dim * 2, stride=2), ConvBlock(conv_dim * 2, conv_dim * 2), ConvBlock(conv_dim * 2, conv_dim * 4, stride=2), ConvBlock(conv_dim * 4, conv_dim * 4), ConvBlock( conv_dim * 4, fc_dim, kernel_size=(H // 8, self.WW // 8), stride=(H // 8, self.WS // 8), padding=0 ), ) self.fc1 = nn.Linear(fc_dim, fc_dim) self.dropout = nn.Dropout(fc_dropout) self.fc2 = nn.Linear(fc_dim, self.num_classes) self._init_weights() def _init_weights(self): """ Initialize weights in a better way than default. 
See https://github.com/pytorch/pytorch/issues/18182 """ for m in self.modules(): if type(m) in { nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear, }: nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_out", nonlinearity="relu") if m.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(m.bias, -bound, bound) def forward(self, x: torch.Tensor) -> torch.Tensor: """Applies the LineCNN to a black-and-white input image. Parameters ---------- x (B, 1, H, W) input image Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and self.window_width C is self.num_classes """ _B, _C, _H, _W = x.shape x = self.convs(x) # (B, FC_DIM, 1, Sx) x = x.squeeze(2).permute(0, 2, 1) # (B, S, FC_DIM) x = F.relu(self.fc1(x)) # -> (B, S, FC_DIM) x = self.dropout(x) x = self.fc2(x) # (B, S, C) x = x.permute(0, 2, 1) # -> (B, C, S) if self.limit_output_length: x = x[:, :, : self.output_length] return x @staticmethod def add_to_argparse(parser): parser.add_argument("--conv_dim", type=int, default=CONV_DIM) parser.add_argument("--fc_dim", type=int, default=FC_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) parser.add_argument( "--window_width", type=int, default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab08/text_recognizer/models/line_cnn_simple.py ================================================ """Simplest version of LineCNN that works on cleanly-separated characters.""" import argparse import math from typing import Any, Dict import torch from 
torch import nn from .cnn import CNN IMAGE_SIZE = 28 WINDOW_WIDTH = IMAGE_SIZE WINDOW_STRIDE = IMAGE_SIZE class LineCNNSimple(nn.Module): """LeNet based model that takes a line of width that is a multiple of CHAR_WIDTH.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config self.WW = self.args.get("window_width", WINDOW_WIDTH) self.WS = self.args.get("window_stride", WINDOW_STRIDE) self.limit_output_length = self.args.get("limit_output_length", False) self.num_classes = len(data_config["mapping"]) self.output_length = data_config["output_dims"][0] cnn_input_dims = (data_config["input_dims"][0], self.WW, self.WW) cnn_data_config = {**data_config, **{"input_dims": cnn_input_dims}} self.cnn = CNN(data_config=cnn_data_config, args=args) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply the LineCNN to an input image and return logits. Parameters ---------- x (B, C, H, W) input image with H equal to IMAGE_SIZE Returns ------- torch.Tensor (B, C, S) logits, where S is the length of the sequence and C is the number of classes S can be computed from W and CHAR_WIDTH C is self.num_classes """ B, _C, H, W = x.shape assert H == IMAGE_SIZE # Make sure we can use our CNN class # Compute number of windows S = math.floor((W - self.WW) / self.WS + 1) # NOTE: type_as properly sets device activations = torch.zeros((B, self.num_classes, S)).type_as(x) for s in range(S): start_w = self.WS * s end_w = start_w + self.WW window = x[:, :, :, start_w:end_w] # -> (B, C, H, self.WW) activations[:, :, s] = self.cnn(window) if self.limit_output_length: # S might not match ground truth, so let's only take enough activations as are expected activations = activations[:, :, : self.output_length] return activations @staticmethod def add_to_argparse(parser): CNN.add_to_argparse(parser) parser.add_argument( "--window_width", type=int, 
default=WINDOW_WIDTH, help="Width of the window that will slide over the input image.", ) parser.add_argument( "--window_stride", type=int, default=WINDOW_STRIDE, help="Stride of the window that will slide over the input image.", ) parser.add_argument("--limit_output_length", action="store_true", default=False) return parser ================================================ FILE: lab08/text_recognizer/models/line_cnn_transformer.py ================================================ """Model that combines a LineCNN with a Transformer model for text prediction.""" import argparse import math from typing import Any, Dict import torch from torch import nn from .line_cnn import LineCNN from .transformer_util import generate_square_subsequent_mask, PositionalEncoding TF_DIM = 256 TF_FC_DIM = 256 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 class LineCNNTransformer(nn.Module): """Process the line through a CNN and process the resulting sequence with a Transformer decoder.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # Instantiate LineCNN with "num_classes" set to self.dim data_config_for_line_cnn = {**data_config} data_config_for_line_cnn["mapping"] = list(range(self.dim)) self.line_cnn = LineCNN(data_config=data_config_for_line_cnn, args=args) # LineCNN outputs (B, E, S) log probs, with E == dim self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.pos_encoder = PositionalEncoding(d_model=self.dim) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (Sx, B, E) logits """ x = self.line_cnn(x) # (B, E, Sx) x = x * math.sqrt(self.dim) x = x.permute(2, 0, 1) # (Sx, B, E) x = self.pos_encoder(x) # (Sx, B, E) return x def decode(self, x, y): """Decode a batch of encoded images x using preceding ground truth y.
Parameters ---------- x (Sx, B, E) image encoded as a sequence y (B, Sy) with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) logits """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output def forward(self, x: torch.Tensor) -> torch.Tensor: """Predict sequences of tokens from input images auto-regressively. Parameters ---------- x (B, H, W) image Returns ------- torch.Tensor (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1:] # Set the last output token # Set all tokens after end token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) @staticmethod def add_to_argparse(parser): LineCNN.add_to_argparse(parser) parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_FC_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser ================================================ FILE: lab08/text_recognizer/models/mlp.py 
================================================ import argparse from typing import Any, Dict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F FC1_DIM = 1024 FC2_DIM = 128 FC_DROPOUT = 0.5 class MLP(nn.Module): """Simple MLP suitable for recognizing single characters.""" def __init__( self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.args = vars(args) if args is not None else {} self.data_config = data_config input_dim = np.prod(self.data_config["input_dims"]) num_classes = len(self.data_config["mapping"]) fc1_dim = self.args.get("fc1", FC1_DIM) fc2_dim = self.args.get("fc2", FC2_DIM) dropout_p = self.args.get("fc_dropout", FC_DROPOUT) self.fc1 = nn.Linear(input_dim, fc1_dim) self.dropout = nn.Dropout(dropout_p) self.fc2 = nn.Linear(fc1_dim, fc2_dim) self.fc3 = nn.Linear(fc2_dim, num_classes) def forward(self, x): x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout(x) x = self.fc2(x) x = F.relu(x) x = self.dropout(x) x = self.fc3(x) return x @staticmethod def add_to_argparse(parser): parser.add_argument("--fc1", type=int, default=FC1_DIM) parser.add_argument("--fc2", type=int, default=FC2_DIM) parser.add_argument("--fc_dropout", type=float, default=FC_DROPOUT) return parser ================================================ FILE: lab08/text_recognizer/models/resnet_transformer.py ================================================ """Model combining a ResNet with a Transformer for image-to-sequence tasks.""" import argparse import math from typing import Any, Dict import torch from torch import nn import torchvision from .transformer_util import generate_square_subsequent_mask, PositionalEncoding, PositionalEncodingImage TF_DIM = 256 TF_FC_DIM = 1024 TF_DROPOUT = 0.4 TF_LAYERS = 4 TF_NHEAD = 4 RESNET_DIM = 512 # hard-coded class ResnetTransformer(nn.Module): """Pass an image through a Resnet and decode the resulting embedding with a Transformer.""" def __init__( 
self, data_config: Dict[str, Any], args: argparse.Namespace = None, ) -> None: super().__init__() self.data_config = data_config self.input_dims = data_config["input_dims"] self.num_classes = len(data_config["mapping"]) self.mapping = data_config["mapping"] inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} self.start_token = inverse_mapping["<S>"] self.end_token = inverse_mapping["<E>"] self.padding_token = inverse_mapping["<P>"] self.max_output_length = data_config["output_dims"][0] self.args = vars(args) if args is not None else {} self.dim = self.args.get("tf_dim", TF_DIM) tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) tf_nhead = self.args.get("tf_nhead", TF_NHEAD) tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) tf_layers = self.args.get("tf_layers", TF_LAYERS) # ## Encoder part - should output vector sequence of length self.dim per sample resnet = torchvision.models.resnet18(weights=None) self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) # encoder_projection will output (B, dim, _H, _W) logits self.enc_pos_encoder = PositionalEncodingImage( d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] ) # Max (Ho, Wo) # ## Decoder part self.embedding = nn.Embedding(self.num_classes, self.dim) self.fc = nn.Linear(self.dim, self.num_classes) self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) self.y_mask = generate_square_subsequent_mask(self.max_output_length) self.transformer_decoder = nn.TransformerDecoder( nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), num_layers=tf_layers, ) self.init_weights() # This is empirically important def forward(self, x: torch.Tensor) -> torch.Tensor: """Autoregressively produce sequences of labels from input images.
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- output_tokens (B, Sy) with elements in [0, C-1] where C is num_classes """ B = x.shape[0] S = self.max_output_length x = self.encode(x) # (Sx, B, E) output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, Sy) output_tokens[:, 0] = self.start_token # Set start token for Sy in range(1, S): y = output_tokens[:, :Sy] # (B, Sy) output = self.decode(x, y) # (Sy, B, C) output = torch.argmax(output, dim=-1) # (Sy, B) output_tokens[:, Sy] = output[-1] # Set the last output token # Early stopping of prediction loop to speed up prediction if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): break # Set all tokens after end or padding token to be padding for Sy in range(1, S): ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) output_tokens[ind, Sy] = self.padding_token return output_tokens # (B, Sy) def init_weights(self): initrange = 0.1 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() self.fc.weight.data.uniform_(-initrange, initrange) nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") if self.encoder_projection.bias is not None: _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.encoder_projection.weight.data) bound = 1 / math.sqrt(fan_out) nn.init.normal_(self.encoder_projection.bias, -bound, bound) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode each image tensor in a batch into a sequence of embeddings. 
Parameters ---------- x (B, Ch, H, W) image, where Ch == 1 or Ch == 3 Returns ------- (Sx, B, E) sequence of embeddings, going left-to-right, top-to-bottom from final ResNet feature maps """ _B, C, _H, _W = x.shape if C == 1: x = x.repeat(1, 3, 1, 1) x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo return x def decode(self, x, y): """Decode a batch of encoded images x with guiding sequences y. During autoregressive inference, the guiding sequence will be previous predictions. During training, the guiding sequence will be the ground truth. Parameters ---------- x (Sx, B, E) images encoded as sequences of embeddings y (B, Sy) guiding sequences with elements in [0, C-1] where C is num_classes Returns ------- torch.Tensor (Sy, B, C) batch of logit sequences """ y_padding_mask = y == self.padding_token y = y.permute(1, 0) # (Sy, B) y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) y = self.dec_pos_encoder(y) # (Sy, B, E) Sy = y.shape[0] y_mask = self.y_mask[:Sy, :Sy].type_as(x) output = self.transformer_decoder( tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask ) # (Sy, B, E) output = self.fc(output) # (Sy, B, C) return output @staticmethod def add_to_argparse(parser): parser.add_argument("--tf_dim", type=int, default=TF_DIM) parser.add_argument("--tf_fc_dim", type=int, default=TF_DIM) parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) return parser 
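The final loop in `ResnetTransformer.forward` above rewrites every position that follows an end or padding token as padding. As a rough illustration only, the sketch below replays that rule on plain Python lists instead of tensors. The function name `pad_after_end` and the token ids 0, 1 and 2 are hypothetical values chosen for this example and are not part of the repository.

```python
from typing import List


def pad_after_end(tokens: List[List[int]], end_token: int, padding_token: int) -> List[List[int]]:
    """Return a copy of tokens where everything after an end or padding token is padding.

    Mirrors the left-to-right pass in the model's forward method: position i is
    overwritten whenever position i - 1 already holds the end or padding token.
    """
    cleaned = [row[:] for row in tokens]  # copy so the input is left untouched
    for row in cleaned:
        for i in range(1, len(row)):
            if row[i - 1] in (end_token, padding_token):
                row[i] = padding_token
    return cleaned


# Hypothetical ids for illustration: 0 = start, 1 = end, 2 = padding.
raw = [
    [0, 5, 6, 1, 7, 8],  # junk predicted after the end token
    [0, 4, 1, 1, 9, 2],  # a repeated end token is also replaced
]
print(pad_after_end(raw, end_token=1, padding_token=2))
```

Because the pass runs left to right and reads the already-updated previous position, a single end token is enough to turn the whole remainder of the row into padding.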
================================================ FILE: lab08/text_recognizer/models/transformer_util.py ================================================ """Position Encoding and other utilities for Transformers.""" import math import torch from torch import Tensor import torch.nn as nn class PositionalEncodingImage(nn.Module): """ Module used to add 2-D positional encodings to the feature-map produced by the encoder. Following https://arxiv.org/abs/2103.06450 by Sumeet Singh. """ def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000, persistent: bool = False) -> None: super().__init__() self.d_model = d_model assert d_model % 2 == 0, f"Embedding depth {d_model} is not even" pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w) self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor: pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1, d_model // 2) pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w) pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2) pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w) pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w) return pe def forward(self, x: Tensor) -> Tensor: """pytorch.nn.module.forward""" # x.shape = (B, d_model, H, W) assert x.shape[1] == self.pe.shape[0] # type: ignore x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore return x class PositionalEncoding(torch.nn.Module): """Classic Attention-is-all-you-need positional encoding.""" def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, persistent: bool = False) -> None: super().__init__() self.dropout = torch.nn.Dropout(p=dropout) pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model)
self.register_buffer( "pe", pe, persistent=persistent ) # not necessary to persist in state_dict, since it can be remade @staticmethod def make_pe(d_model: int, max_len: int) -> torch.Tensor: pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(1) return pe def forward(self, x: torch.Tensor) -> torch.Tensor: # x.shape = (S, B, d_model) assert x.shape[2] == self.pe.shape[2] # type: ignore x = x + self.pe[: x.size(0)] # type: ignore return self.dropout(x) def generate_square_subsequent_mask(size: int) -> torch.Tensor: """Generate a triangular (size, size) mask.""" mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask ================================================ FILE: lab08/text_recognizer/paragraph_text_recognizer.py ================================================ """Detects a paragraph of text in an input image. 
Example usage as a script: python text_recognizer/paragraph_text_recognizer.py \ text_recognizer/tests/support/paragraphs/a01-077.png python text_recognizer/paragraph_text_recognizer.py \ https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png """ import argparse from pathlib import Path from typing import Sequence, Union from PIL import Image import torch from text_recognizer import util from text_recognizer.stems.paragraph import ParagraphStem STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent / "artifacts" / "paragraph-text-recognizer" MODEL_FILE = "model.pt" class ParagraphTextRecognizer: """Recognizes a paragraph of text in an image.""" def __init__(self, model_path=None): if model_path is None: model_path = STAGED_MODEL_DIRNAME / MODEL_FILE self.model = torch.jit.load(model_path) self.mapping = self.model.mapping self.ignore_tokens = self.model.ignore_tokens self.stem = ParagraphStem() @torch.no_grad() def predict(self, image: Union[str, Path, Image.Image]) -> str: """Predict/infer text in input image (which can be a file path or url).""" image_pil = image if not isinstance(image, Image.Image): image_pil = util.read_image_pil(image, grayscale=True) image_tensor = self.stem(image_pil).unsqueeze(axis=0) y_pred = self.model(image_tensor)[0] pred_str = convert_y_label_to_string(y=y_pred, mapping=self.mapping, ignore_tokens=self.ignore_tokens) return pred_str def convert_y_label_to_string(y: torch.Tensor, mapping: Sequence[str], ignore_tokens: Sequence[int]) -> str: return "".join([mapping[i] for i in y if i not in ignore_tokens]) def main(): parser = argparse.ArgumentParser(description=__doc__.split("\n")[0]) parser.add_argument( "filename", type=str, help="Name for an image file. 
This can be a local path, a URL, a URI from AWS/GCP/Azure storage, an HDFS path, or any other resource locator supported by the smart_open library.", ) args = parser.parse_args() text_recognizer = ParagraphTextRecognizer() pred_str = text_recognizer.predict(args.filename) print(pred_str) if __name__ == "__main__": main() ================================================ FILE: lab08/text_recognizer/stems/image.py ================================================ import torch from torchvision import transforms class ImageStem: """A stem for models operating on images. Images are presumed to be provided as PIL images, as is standard for torchvision Datasets. Transforms are split into two categories: pil_transforms, which take in and return PIL images, and torch_transforms, which take in and return Torch tensors. By default, these two transforms are both identities. In between, the images are mapped to tensors. The torch_transforms are wrapped in a torch.nn.Sequential and so are compatible with torchscript if the underlying Modules are compatible.
""" def __init__(self): self.pil_transforms = transforms.Compose([]) self.pil_to_tensor = transforms.ToTensor() self.torch_transforms = torch.nn.Sequential() def __call__(self, img): img = self.pil_transforms(img) img = self.pil_to_tensor(img) with torch.no_grad(): img = self.torch_transforms(img) return img class MNISTStem(ImageStem): """A stem for handling images from the MNIST dataset.""" def __init__(self): super().__init__() self.torch_transforms = torch.nn.Sequential(transforms.Normalize((0.1307,), (0.3081,))) ================================================ FILE: lab08/text_recognizer/stems/line.py ================================================ import random from PIL import Image from torchvision import transforms import text_recognizer.metadata.iam_lines as metadata from text_recognizer.stems.image import ImageStem class LineStem(ImageStem): """A stem for handling images containing a line of text.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.5, 1)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "translate": (0, 0.05), "scale": (0.4, 1.1), "shear": (-40, 50), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } if augment: self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] ) class IAMLineStem(ImageStem): """A stem for handling images containing lines of text from the IAMLines dataset.""" def __init__(self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None): super().__init__() def embed_crop(crop, augment=augment): # crop is PIL.image of dtype="L" (so values range from 0 -> 255) image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) # Resize crop crop_width, crop_height = crop.size new_crop_height = metadata.IMAGE_HEIGHT new_crop_width = int(new_crop_height * 
(crop_width / crop_height)) if augment: # Add random stretching new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) crop_resized = crop.resize((new_crop_width, new_crop_height), resample=Image.BILINEAR) # Embed in the image x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) y = metadata.IMAGE_HEIGHT - new_crop_height image.paste(crop_resized, (x, y)) return image if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": (0.8, 1.6)} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 1, "shear": (-30, 20), "interpolation": transforms.InterpolationMode.BILINEAR, "fill": 0, } pil_transforms_list = [transforms.Lambda(embed_crop)] if augment: pil_transforms_list += [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomAffine(**random_affine_kwargs), ] self.pil_transforms = transforms.Compose(pil_transforms_list) ================================================ FILE: lab08/text_recognizer/stems/paragraph.py ================================================ """IAMParagraphs Stem class.""" import torchvision.transforms as transforms import text_recognizer.metadata.iam_paragraphs as metadata from text_recognizer.stems.image import ImageStem IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH IMAGE_SHAPE = metadata.IMAGE_SHAPE MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH class ParagraphStem(ImageStem): """A stem for handling images that contain a paragraph of text.""" def __init__( self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None, random_perspective_kwargs=None, gaussian_blur_kwargs=None, sharpness_kwargs=None, ): super().__init__() if not augment: self.pil_transforms = transforms.Compose([transforms.CenterCrop(IMAGE_SHAPE)]) else: if color_jitter_kwargs is None: color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} if random_affine_kwargs is None: random_affine_kwargs = { "degrees": 3, "shear": 6, 
"scale": (0.95, 1), "interpolation": transforms.InterpolationMode.BILINEAR, } if random_perspective_kwargs is None: random_perspective_kwargs = { "distortion_scale": 0.2, "p": 0.5, "interpolation": transforms.InterpolationMode.BILINEAR, } if gaussian_blur_kwargs is None: gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} if sharpness_kwargs is None: sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} # IMAGE_SHAPE is (576, 640) self.pil_transforms = transforms.Compose( [ transforms.ColorJitter(**color_jitter_kwargs), transforms.RandomCrop( size=IMAGE_SHAPE, padding=None, pad_if_needed=True, fill=0, padding_mode="constant" ), transforms.RandomAffine(**random_affine_kwargs), transforms.RandomPerspective(**random_perspective_kwargs), transforms.GaussianBlur(**gaussian_blur_kwargs), transforms.RandomAdjustSharpness(**sharpness_kwargs), ] ) ================================================ FILE: lab08/text_recognizer/tests/test_callback_utils.py ================================================ """Tests for the text_recognizer.callbacks.util module.""" import random import string import tempfile import pytorch_lightning as pl from text_recognizer.callbacks.util import check_and_warn def test_check_and_warn_simple(): """Test the success and failure in the case of a simple class we control.""" class Foo: pass # a class with no special attributes letters = string.ascii_lowercase random_attribute = "".join(random.choices(letters, k=10)) assert check_and_warn(Foo(), random_attribute, "random feature") assert not check_and_warn(Foo(), "__doc__", "feature of all Python objects") def test_check_and_warn_tblogger(): """Test that we return a truthy value when trying to log tables with TensorBoard. We added check_and_warn in order to prevent a crash if this happens. 
""" tblogger = pl.loggers.TensorBoardLogger(save_dir=tempfile.TemporaryDirectory()) assert check_and_warn(tblogger, "log_table", "tables") def test_check_and_warn_wandblogger(): """Test that we return a falsy value when we try to log tables with W&B. In adding check_and_warn, we don't want to block the feature in the happy path. """ wandblogger = pl.loggers.WandbLogger(anonymous=True) assert not check_and_warn(wandblogger, "log_table", "tables") ================================================ FILE: lab08/text_recognizer/tests/test_iam.py ================================================ """Test for data.iam module.""" from text_recognizer.data.iam import IAM def test_iam_parsed_lines(): """Tests that we retrieve the same number of line labels and line image cropregions.""" iam = IAM() iam.prepare_data() for iam_id in iam.all_ids: assert len(iam.line_strings_by_id[iam_id]) == len(iam.line_regions_by_id[iam_id]) def test_iam_data_splits(): """Fails when any identifiers are shared between training, test, or validation.""" iam = IAM() iam.prepare_data() assert not set(iam.train_ids) & set(iam.validation_ids) assert not set(iam.train_ids) & set(iam.test_ids) assert not set(iam.validation_ids) & set(iam.test_ids) ================================================ FILE: lab08/text_recognizer/util.py ================================================ """Utility functions for text_recognizer module.""" import base64 import contextlib import hashlib from io import BytesIO import os from pathlib import Path from typing import Union from urllib.request import urlretrieve import numpy as np from PIL import Image import smart_open from tqdm import tqdm def to_categorical(y, num_classes): """1-hot encode a tensor.""" return np.eye(num_classes, dtype="uint8")[y] def read_image_pil(image_uri: Union[Path, str], grayscale=False) -> Image: with smart_open.open(image_uri, "rb") as image_file: return read_image_pil_file(image_file, grayscale) def read_image_pil_file(image_file, 
grayscale=False) -> Image: with Image.open(image_file) as image: if grayscale: image = image.convert(mode="L") else: image = image.convert(mode=image.mode) return image @contextlib.contextmanager def temporary_working_directory(working_dir: Union[str, Path]): """Temporarily switches to a directory, then returns to the original directory on exit.""" curdir = os.getcwd() os.chdir(working_dir) try: yield finally: os.chdir(curdir) def read_b64_image(b64_string, grayscale=False): """Load base64-encoded images.""" try: image_file = read_b64_string(b64_string) return read_image_pil_file(image_file, grayscale) except Exception as exception: raise ValueError("Could not load image from b64 {}: {}".format(b64_string, exception)) from exception def read_b64_string(b64_string, return_data_type=False): """Read a base64-encoded string into an in-memory file-like object.""" data_header, b64_data = split_and_validate_b64_string(b64_string) b64_buffer = BytesIO(base64.b64decode(b64_data)) if return_data_type: return get_b64_filetype(data_header), b64_buffer else: return b64_buffer def get_b64_filetype(data_header): """Retrieves the filetype information from the data type header of a base64-encoded object.""" _, file_type = data_header.split("/") return file_type def split_and_validate_b64_string(b64_string): """Return the data_type and data of a b64 string, with validation.""" header, data = b64_string.split(",", 1) assert header.startswith("data:") assert header.endswith(";base64") data_type = header.split(";")[0].split(":")[1] return data_type, data def encode_b64_image(image, format="png"): """Encode a PIL image as a base64 string.""" _buffer = BytesIO() # bytes that live in memory image.save(_buffer, format=format) # but which we write to like a file encoded_image = base64.b64encode(_buffer.getvalue()).decode("utf8") return encoded_image def compute_sha256(filename: Union[Path, str]): """Return SHA256 checksum of a file.""" with open(filename, "rb") as f: return 
hashlib.sha256(f.read()).hexdigest() class TqdmUpTo(tqdm): """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py""" def update_to(self, blocks=1, bsize=1, tsize=None): """ Parameters ---------- blocks: int, optional Number of blocks transferred so far [default: 1]. bsize: int, optional Size of each block (in tqdm units) [default: 1]. tsize: int, optional Total size (in tqdm units). If [default: None] remains unchanged. """ if tsize is not None: self.total = tsize self.update(blocks * bsize - self.n) # will also set self.n = b * bsize def download_url(url, filename): """Download a file from url to filename, with a progress bar.""" with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: urlretrieve(url, filename, reporthook=t.update_to, data=None) # noqa: S310 ================================================ FILE: lab08/training/__init__.py ================================================ ================================================ FILE: lab08/training/cleanup_artifacts.py ================================================ """Removes artifacts from projects and runs. Artifacts are binary files that we want to track and version but don't want to include in git, generally because they are too large, because they don't have meaningful diffs, or because they change more quickly than code. During development, we often generate artifacts that we don't really need, e.g. model weights for an overfitting test run. Space on artifact storage is generally very large, but it is limited, so we should occasionally delete unneeded artifacts to reclaim some of that space. 
For usage help, run python training/cleanup_artifacts.py --help """ import argparse import wandb api = wandb.Api() DEFAULT_PROJECT = "fsdl-text-recognizer-2022-training" DEFAULT_ENTITY = api.default_entity def _setup_parser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--entity", type=str, default=None, help="The entity from which to remove artifacts. Provide the value DEFAULT " + f"to use the default WANDB_ENTITY, which is currently {DEFAULT_ENTITY}.", ) parser.add_argument( "--project", type=str, default=DEFAULT_PROJECT, help=f"The project from which to remove artifacts. Default is {DEFAULT_PROJECT}", ) parser.add_argument( "--run_ids", type=str, default=None, nargs="*", help="One or more run IDs from which to remove artifacts. Default is None.", ) parser.add_argument( "--run_name_res", type=str, default=None, nargs="*", help="One or more regular expressions to use to select runs (by display name) from which to remove artifacts. See wandb.Api.runs documentation for details on the syntax. Beware that this is a footgun and consider using interactively with --dryrun and -v. Default is None.", metavar="RUN_NAME_REGEX", ) flags = parser.add_mutually_exclusive_group() flags.add_argument("--all", action="store_true", help="Delete all artifacts from selected runs.") flags.add_argument( "--no-alias", action="store_true", help="Delete all artifacts without an alias from selected runs." 
) flags.add_argument( "--aliases", type=str, nargs="*", help="Delete artifacts that have any of the aliases from the provided list from selected runs.", ) parser.add_argument( "-v", action="store_true", dest="verbose", help="Display information about targeted entities, projects, runs, and artifacts.", ) parser.add_argument( "--dryrun", action="store_true", help="Select artifacts without deleting them and display which artifacts were selected.", ) return parser def main(args): entity = _get_entity_from(args) project_path = f"{entity}/{args.project}" runs = _get_runs(project_path, args.run_ids, args.run_name_res, verbose=args.verbose) artifact_selector = _get_selector_from(args) protect_aliases = args.no_alias # avoid deletion of any aliased artifacts for run in runs: clean_run_artifacts( run, selector=artifact_selector, protect_aliases=protect_aliases, verbose=args.verbose, dryrun=args.dryrun ) def clean_run_artifacts(run, selector, protect_aliases=True, verbose=False, dryrun=True): artifacts = run.logged_artifacts() for artifact in artifacts: if selector(artifact): remove_artifact(artifact, protect_aliases=protect_aliases, verbose=verbose, dryrun=dryrun) def remove_artifact(artifact, protect_aliases, verbose=False, dryrun=True): project, entity, id = artifact.project, artifact.entity, artifact.id type, aliases = artifact.type, artifact.aliases if verbose or dryrun: print(f"selecting for deletion artifact {project}/{entity}/{id} of type {type} with aliases {aliases}") if not dryrun: artifact.delete(delete_aliases=not protect_aliases) def _get_runs(project_path, run_ids=None, run_name_res=None, verbose=False): if run_ids is None: run_ids = [] if run_name_res is None: run_name_res = [] runs = [] for run_id in run_ids: runs.append(_get_run_by_id(project_path, run_id, verbose=verbose)) for run_name_re in run_name_res: runs += _get_runs_by_name_re(project_path, run_name_re, verbose=verbose) return runs def _get_run_by_id(project_path, run_id, verbose=False): path = 
f"{project_path}/{run_id}" run = api.run(path) if verbose: print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}") return run def _get_runs_by_name_re(project_path, run_name_re, verbose=False): matching_runs = api.runs(path=project_path, filters={"display_name": {"$regex": run_name_re}}) if verbose: for run in matching_runs: print(f"selecting run {run.entity}/{run.project}/{run.id} with display name {run.name}") return matching_runs def _get_selector_from(args, verbose=False): if args.all: if verbose: print("removing all artifacts from matching runs") return lambda _: True if args.no_alias: if verbose: print("removing all artifacts with no aliases from matching runs") return lambda artifact: artifact.aliases == [] if args.aliases: if verbose: print(f"removing all artifacts with any of {args.aliases} in aliases from matching runs") return lambda artifact: any(alias in artifact.aliases for alias in args.aliases) if verbose: print("removing no artifacts matching runs") return lambda _: False def _get_entity_from(args, verbose=False): entity = args.entity if entity is None: raise RuntimeError(f"No entity argument provided. 
Use --entity=DEFAULT to use {DEFAULT_ENTITY}.") elif entity == "DEFAULT": entity = DEFAULT_ENTITY if verbose: print(f"using default entity {entity}") else: if verbose: print(f"using entity {entity}") return entity if __name__ == "__main__": parser = _setup_parser() args = parser.parse_args() main(args) ================================================ FILE: lab08/training/run_experiment.py ================================================ """Experiment-running framework.""" import argparse from pathlib import Path import numpy as np import pytorch_lightning as pl from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only import torch from text_recognizer import callbacks as cb from text_recognizer import lit_models from training.util import DATA_CLASS_MODULE, import_class, MODEL_CLASS_MODULE, setup_data_and_model_from_args # In order to ensure reproducible experiments, we must set random seeds. np.random.seed(42) torch.manual_seed(42) def _setup_parser(): """Set up Python's ArgumentParser with data, model, trainer, and other arguments.""" parser = argparse.ArgumentParser(add_help=False) # Add Trainer specific arguments, such as --max_epochs, --gpus, --precision trainer_parser = pl.Trainer.add_argparse_args(parser) trainer_parser._action_groups[1].title = "Trainer Args" parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser]) parser.set_defaults(max_epochs=1) # Basic arguments parser.add_argument( "--wandb", action="store_true", default=False, help="If passed, logs experiment results to Weights & Biases. 
Otherwise logs only to local Tensorboard.", ) parser.add_argument( "--profile", action="store_true", default=False, help="If passed, uses the PyTorch Profiler to track computation, exported as a Chrome-style trace.", ) parser.add_argument( "--data_class", type=str, default="MNIST", help=f"String identifier for the data class, relative to {DATA_CLASS_MODULE}.", ) parser.add_argument( "--model_class", type=str, default="MLP", help=f"String identifier for the model class, relative to {MODEL_CLASS_MODULE}.", ) parser.add_argument( "--load_checkpoint", type=str, default=None, help="If passed, loads a model from the provided path." ) parser.add_argument( "--stop_early", type=int, default=0, help="If non-zero, applies early stopping, with the provided value as the 'patience' argument." + " Default is 0.", ) # Get the data and model classes, so that we can add their specific arguments temp_args, _ = parser.parse_known_args() data_class = import_class(f"{DATA_CLASS_MODULE}.{temp_args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{temp_args.model_class}") # Get data, model, and LitModel specific arguments data_group = parser.add_argument_group("Data Args") data_class.add_to_argparse(data_group) model_group = parser.add_argument_group("Model Args") model_class.add_to_argparse(model_group) lit_model_group = parser.add_argument_group("LitModel Args") lit_models.BaseLitModel.add_to_argparse(lit_model_group) parser.add_argument("--help", "-h", action="help") return parser @rank_zero_only def _ensure_logging_dir(experiment_dir): """Create the logging directory via the rank-zero process, if necessary.""" Path(experiment_dir).mkdir(parents=True, exist_ok=True) def main(): """ Run an experiment. 
Sample command: ``` python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --model_class=MLP --data_class=MNIST ``` For basic help documentation, run the command ``` python training/run_experiment.py --help ``` The available command line args differ depending on some of the arguments, including --model_class and --data_class. To see which command line args are available and read their documentation, provide values for those arguments before invoking --help, like so: ``` python training/run_experiment.py --model_class=MLP --data_class=MNIST --help ``` """ parser = _setup_parser() args = parser.parse_args() data, model = setup_data_and_model_from_args(args) lit_model_class = lit_models.BaseLitModel if args.loss == "transformer": lit_model_class = lit_models.TransformerLitModel if args.load_checkpoint is not None: lit_model = lit_model_class.load_from_checkpoint(args.load_checkpoint, args=args, model=model) else: lit_model = lit_model_class(args=args, model=model) log_dir = Path("training") / "logs" _ensure_logging_dir(log_dir) logger = pl.loggers.TensorBoardLogger(log_dir) experiment_dir = logger.log_dir goldstar_metric = "validation/cer" if args.loss in ("transformer",) else "validation/loss" filename_format = "epoch={epoch:04d}-validation.loss={validation/loss:.3f}" if goldstar_metric == "validation/cer": filename_format += "-validation.cer={validation/cer:.3f}" checkpoint_callback = pl.callbacks.ModelCheckpoint( save_top_k=5, filename=filename_format, monitor=goldstar_metric, mode="min", auto_insert_metric_name=False, dirpath=experiment_dir, every_n_epochs=args.check_val_every_n_epoch, ) summary_callback = pl.callbacks.ModelSummary(max_depth=2) callbacks = [summary_callback, checkpoint_callback] if args.wandb: logger = pl.loggers.WandbLogger(log_model="all", save_dir=str(log_dir), job_type="train") logger.watch(model, log_freq=max(100, args.log_every_n_steps)) logger.log_hyperparams(vars(args)) experiment_dir = logger.experiment.dir callbacks +=
[cb.ModelSizeLogger(), cb.LearningRateMonitor()] if args.stop_early: early_stopping_callback = pl.callbacks.EarlyStopping( monitor="validation/loss", mode="min", patience=args.stop_early ) callbacks.append(early_stopping_callback) if args.wandb and args.loss in ("transformer",): callbacks.append(cb.ImageToTextLogger()) trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger) if args.profile: sched = torch.profiler.schedule(wait=0, warmup=3, active=4, repeat=0) profiler = pl.profiler.PyTorchProfiler(export_to_chrome=True, schedule=sched, dirpath=experiment_dir) profiler.STEP_FUNCTIONS = {"training_step"} # only profile training else: profiler = pl.profiler.PassThroughProfiler() trainer.profiler = profiler trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate trainer.fit(lit_model, datamodule=data) trainer.profiler = pl.profiler.PassThroughProfiler() # turn profiling off during testing best_model_path = checkpoint_callback.best_model_path if best_model_path: rank_zero_info(f"Best model saved at: {best_model_path}") if args.wandb: rank_zero_info("Best model also uploaded to W&B ") trainer.test(datamodule=data, ckpt_path=best_model_path) else: trainer.test(lit_model, datamodule=data) if __name__ == "__main__": main() ================================================ FILE: lab08/training/stage_model.py ================================================ """Stages a model for use in production. If based on a checkpoint, the model is converted to torchscript, saved locally, and uploaded to W&B. If based on a model that is already converted and uploaded, the model file is downloaded locally. For details on how the W&B artifacts backing the checkpoints and models are handled, see the documentation for stage_model.find_artifact.
""" import argparse from pathlib import Path import tempfile import torch import wandb from text_recognizer.lit_models import TransformerLitModel from training.util import setup_data_and_model_from_args # these names are all set by the pl.loggers.WandbLogger MODEL_CHECKPOINT_TYPE = "model" BEST_CHECKPOINT_ALIAS = "best" MODEL_CHECKPOINT_PATH = "model.ckpt" LOG_DIR = Path("training") / "logs" STAGED_MODEL_TYPE = "prod-ready" # we can choose the name of this type, and ideally it's different from checkpoints STAGED_MODEL_FILENAME = "model.pt" # standard nomenclature; pytorch_model.bin is also used PROJECT_ROOT = Path(__file__).resolve().parents[1] LITMODEL_CLASS = TransformerLitModel api = wandb.Api() DEFAULT_ENTITY = api.default_entity DEFAULT_FROM_PROJECT = "fsdl-text-recognizer-2022-training" DEFAULT_TO_PROJECT = "fsdl-text-recognizer-2022-training" DEFAULT_STAGED_MODEL_NAME = "paragraph-text-recognizer" PROD_STAGING_ROOT = PROJECT_ROOT / "text_recognizer" / "artifacts" def main(args): prod_staging_directory = PROD_STAGING_ROOT / args.staged_model_name prod_staging_directory.mkdir(exist_ok=True, parents=True) entity = _get_entity_from(args) # if we're just fetching an already compiled model if args.fetch: # find it and download it staged_model = f"{entity}/{args.from_project}/{args.staged_model_name}:latest" artifact = download_artifact(staged_model, prod_staging_directory) print_info(artifact) return # and we're done # otherwise, we'll need to download the weights, compile the model, and save it with wandb.init( job_type="stage", project=args.to_project, dir=LOG_DIR ): # log staging to W&B so prod and training are connected # find the model checkpoint and retrieve its artifact name and an api handle ckpt_at, ckpt_api = find_artifact( entity, args.from_project, type=MODEL_CHECKPOINT_TYPE, alias=args.ckpt_alias, run=args.run ) # get the run that produced that checkpoint logging_run = get_logging_run(ckpt_api) print_info(ckpt_api, logging_run) metadata = 
get_checkpoint_metadata(logging_run, ckpt_api) # create an artifact for the staged, deployable model staged_at = wandb.Artifact(args.staged_model_name, type=STAGED_MODEL_TYPE, metadata=metadata) with tempfile.TemporaryDirectory() as tmp_dir: # download the checkpoint to a temporary directory download_artifact(ckpt_at, tmp_dir) # reload the model from that checkpoint model = load_model_from_checkpoint(metadata, directory=tmp_dir) # save the model to torchscript in the staging directory save_model_to_torchscript(model, directory=prod_staging_directory) # upload the staged model so it can be downloaded elsewhere upload_staged_model(staged_at, from_directory=prod_staging_directory) def find_artifact(entity: str, project: str, type: str, alias: str, run=None): """Finds the artifact of a given type with a given alias under the entity and project. Parameters ---------- entity The name of the W&B entity under which the artifact is logged. project The name of the W&B project under which the artifact is logged. type The name of the type of the artifact. alias : str The alias for this artifact. This alias must be unique within the provided type for the run, if provided, or for the project, if the run is not provided. run : str Optionally, the run in which the artifact is located. Returns ------- Tuple[path, artifact] An identifying path and an API handle for a matching artifact. 
""" if run is not None: path = _find_artifact_run(entity, project, type=type, run=run, alias=alias) else: path = _find_artifact_project(entity, project, type=type, alias=alias) return path, api.artifact(path) def get_logging_run(artifact): api_run = artifact.logged_by() return api_run def print_info(artifact, run=None): if run is None: run = get_logging_run(artifact) full_artifact_name = f"{artifact.entity}/{artifact.project}/{artifact.name}" print(f"Using artifact {full_artifact_name}") artifact_url_prefix = f"https://wandb.ai/{artifact.entity}/{artifact.project}/artifacts/{artifact.type}" artifact_url_suffix = f"{artifact.name.replace(':', '/')}" print(f"View at URL: {artifact_url_prefix}/{artifact_url_suffix}") print(f"Logged by {run.name} -- {run.project}/{run.entity}/{run.id}") print(f"View at URL: {run.url}") def get_checkpoint_metadata(run, checkpoint): config = run.config out = {"config": config} try: ckpt_filename = checkpoint.metadata["original_filename"] out["original_filename"] = ckpt_filename metric_key = checkpoint.metadata["ModelCheckpoint"]["monitor"] metric_score = checkpoint.metadata["score"] out[metric_key] = metric_score except KeyError: pass return out def download_artifact(artifact_path, target_directory): """Downloads the artifact at artifact_path to the target directory.""" if wandb.run is not None: # if we are inside a W&B run, track that we used this artifact artifact = wandb.use_artifact(artifact_path) else: # otherwise, just download the artifact via the API artifact = api.artifact(artifact_path) artifact.download(root=target_directory) return artifact def load_model_from_checkpoint(ckpt_metadata, directory): config = ckpt_metadata["config"] args = argparse.Namespace(**config) _, model = setup_data_and_model_from_args(args) # load LightningModule from checkpoint pth = Path(directory) / MODEL_CHECKPOINT_PATH lit_model = LITMODEL_CLASS.load_from_checkpoint(checkpoint_path=pth, args=args, model=model, strict=False) lit_model.eval() return 
lit_model def save_model_to_torchscript(model, directory): scripted_model = model.to_torchscript(method="script", file_path=None) path = Path(directory) / STAGED_MODEL_FILENAME torch.jit.save(scripted_model, path) def upload_staged_model(staged_at, from_directory): staged_at.add_file(Path(from_directory) / STAGED_MODEL_FILENAME) wandb.log_artifact(staged_at) def _find_artifact_run(entity, project, type, run, alias): run_name = f"{entity}/{project}/{run}" api_run = api.run(run_name) artifacts = api_run.logged_artifacts() match = [art for art in artifacts if alias in art.aliases and art.type == type] if not match: raise ValueError(f"No artifact with alias {alias} found at {run_name} of type {type}") if len(match) > 1: raise ValueError(f"Multiple artifacts ({len(match)}) with alias {alias} found at {run_name} of type {type}") return f"{entity}/{project}/{match[0].name}" def _find_artifact_project(entity, project, type, alias): project_name = f"{entity}/{project}" api_project = api.project(project, entity=entity) api_artifact_types = api_project.artifacts_types() # loop through all artifact types in this project for artifact_type in api_artifact_types: if artifact_type.name != type: continue # skipping those that don't match type collections = artifact_type.collections() # loop through all artifacts and their versions for collection in collections: versions = collection.versions() for version in versions: if alias in version.aliases: # looking for the first one that matches the alias return f"{project_name}/{version.name}" raise ValueError(f"Artifact with alias {alias} not found in type {type} in {project_name}") raise ValueError(f"Artifact type {type} not found. {project_name} could be private or not exist.") def _get_entity_from(args): entity = args.entity if entity is None: raise RuntimeError(f"No entity argument provided. 
Use --entity=DEFAULT to use {DEFAULT_ENTITY}.") elif entity == "DEFAULT": entity = DEFAULT_ENTITY return entity def _setup_parser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--fetch", action="store_true", help=f"If provided, check ENTITY/FROM_PROJECT for an artifact with the provided STAGED_MODEL_NAME and download its latest version to {PROD_STAGING_ROOT}/STAGED_MODEL_NAME.", ) parser.add_argument( "--entity", type=str, default=None, help=f"Entity from which to download the checkpoint. Note that checkpoints are always uploaded to the logged-in wandb entity. Pass the value 'DEFAULT' to also download from default entity, which is currently {DEFAULT_ENTITY}.", ) parser.add_argument( "--from_project", type=str, default=DEFAULT_FROM_PROJECT, help=f"Project from which to download the checkpoint. Default is {DEFAULT_FROM_PROJECT}.", ) parser.add_argument( "--to_project", type=str, default=DEFAULT_TO_PROJECT, help=f"Project to which to upload the compiled model. Default is {DEFAULT_TO_PROJECT}.", ) parser.add_argument( "--run", type=str, default=None, help=f"Optionally, the name of a run to check for an artifact of type {MODEL_CHECKPOINT_TYPE} that has the provided CKPT_ALIAS. Default is None.", ) parser.add_argument( "--ckpt_alias", type=str, default=BEST_CHECKPOINT_ALIAS, help=f"Alias that identifies which model checkpoint should be staged. The artifact's alias can be set manually or programmatically elsewhere. Default is {BEST_CHECKPOINT_ALIAS!r}.", ) parser.add_argument( "--staged_model_name", type=str, default=DEFAULT_STAGED_MODEL_NAME, help=f"Name to give the staged model artifact.
Default is {DEFAULT_STAGED_MODEL_NAME!r}.", ) return parser if __name__ == "__main__": parser = _setup_parser() args = parser.parse_args() main(args) ================================================ FILE: lab08/training/tests/test_memorize_iam.sh ================================================ #!/bin/bash set -uo pipefail set +e # tests whether we can achieve a criterion loss # on a single batch within a certain number of epochs FAILURE=false # constants and CLI args set by aiming for <5 min test on commodity GPU, # including data download step MAX_EPOCHS="${1:-100}" # syntax for basic optional arguments in bash CRITERION="${2:-1.0}" # train on GPU if it's available GPU=$(python -c 'import torch; print(int(torch.cuda.is_available()))') python ./training/run_experiment.py \ --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --limit_test_batches 0.0 --overfit_batches 1 --num_sanity_val_steps 0 \ --augment_data false --tf_dropout 0.0 \ --gpus "$GPU" --precision 16 --batch_size 16 --lr 0.0001 \ --log_every_n_steps 25 --max_epochs "$MAX_EPOCHS" --num_workers 2 --wandb || FAILURE=true python -c "import json; loss = json.load(open('training/logs/wandb/latest-run/files/wandb-summary.json'))['train/loss']; assert loss < $CRITERION" || FAILURE=true if [ "$FAILURE" = true ]; then echo "Memorization test failed at loss criterion $CRITERION" exit 1 fi echo "Memorization test passed at loss criterion $CRITERION" exit 0 ================================================ FILE: lab08/training/tests/test_model_development.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false CI="${CI:-false}" if [ "$CI" = false ]; then export WANDB_PROJECT="fsdl-testing-2022" else export WANDB_PROJECT="fsdl-testing-2022-ci" fi echo "training smaller version of real model class on real data" python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --tf_dim 4 --tf_fc_dim 
2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \ --limit_train_batches 1 --limit_val_batches 1 --limit_test_batches 1 --num_sanity_val_steps 0 \ --num_workers 1 --wandb || FAILURE=true TRAIN_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//") echo "staging trained model from run $TRAIN_RUN" python training/stage_model.py --entity DEFAULT --run "$TRAIN_RUN" --staged_model_name test-dummy --ckpt_alias latest --to_project "$WANDB_PROJECT" --from_project "$WANDB_PROJECT" || FAILURE=true echo "fetching staged model" python training/stage_model.py --entity DEFAULT --fetch --from_project "$WANDB_PROJECT" --staged_model_name test-dummy || FAILURE=true STAGE_RUN=$(find ./training/logs/wandb/latest-run/* | grep -Eo "run-([[:alnum:]])+\.wandb" | sed -e "s/^run-//" -e "s/\.wandb//") if [ "$FAILURE" = true ]; then echo "Model development test failed" echo "cleaning up local files" rm -rf text_recognizer/artifacts/test-dummy echo "leaving remote files in place" exit 1 fi echo "cleaning up local and remote files" rm -rf text_recognizer/artifacts/test-dummy python training/cleanup_artifacts.py --entity DEFAULT --project "$WANDB_PROJECT" \ --run_ids "$TRAIN_RUN" "$STAGE_RUN" --all -v # note: if $TRAIN_RUN and $STAGE_RUN are not set, this will fail. # that's good because it avoids all artifacts from the project being deleted due to the --all.
echo "Model development test passed" exit 0 ================================================ FILE: lab08/training/tests/test_run_experiment.sh ================================================ #!/bin/bash set -uo pipefail set +e FAILURE=false echo "running full loop test with CNN on fake data" python training/run_experiment.py --data_class=FakeImageData --model_class=CNN --conv_dim=2 --fc_dim=2 --loss=cross_entropy --num_workers=4 --max_epochs=1 || FAILURE=true echo "running fast_dev_run test of real model class on real data" python training/run_experiment.py --data_class=IAMParagraphs --model_class=ResnetTransformer --loss=transformer \ --tf_dim 4 --tf_fc_dim 2 --tf_layers 2 --tf_nhead 2 --batch_size 2 --lr 0.0001 \ --fast_dev_run --num_sanity_val_steps 0 \ --num_workers 1 || FAILURE=true if [ "$FAILURE" = true ]; then echo "Test for run_experiment.py failed" exit 1 fi echo "Tests for run_experiment.py passed" exit 0 ================================================ FILE: lab08/training/util.py ================================================ """Utilities for model development scripts: training and staging.""" import argparse import importlib DATA_CLASS_MODULE = "text_recognizer.data" MODEL_CLASS_MODULE = "text_recognizer.models" def import_class(module_and_class_name: str) -> type: """Import class from a module, e.g. 
'text_recognizer.models.MLP'.""" module_name, class_name = module_and_class_name.rsplit(".", 1) module = importlib.import_module(module_name) class_ = getattr(module, class_name) return class_ def setup_data_and_model_from_args(args: argparse.Namespace): data_class = import_class(f"{DATA_CLASS_MODULE}.{args.data_class}") model_class = import_class(f"{MODEL_CLASS_MODULE}.{args.model_class}") data = data_class(args) model = model_class(data_config=data.config(), args=args) return data, model ================================================ FILE: overview.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "7yQQTA9IGDt8" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "MX9n-Zed8G_T" }, "source": [ "# Lab 00: The 🥞 Full Stack 🥞 of the Text Recognizer Application" ] }, { "cell_type": "markdown", "metadata": { "id": "OggjLhU3f9gk" }, "source": [ "In the course of these labs,\n", "you will build an optical character recognition (OCR) application\n", "that is powered by a neural network:\n", "the \"FSDL Text Recognizer\".\n", "\n", "We use this application to\n", "- demonstrate general principles for engineering an ML-powered application,\n", "- provide a \"worked example\" that includes all of the juicy details, and\n", "- introduce you to tools, libraries, and practices that we consider best-in-class or best for independent ML engineers working across the full stack.\n", "\n", "You can try it out inside this notebook below,\n", "or you can simply navigate to the `app_url` in your browser." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g9xKjSYie6ck" }, "outputs": [], "source": [ "from IPython.display import IFrame\n", "\n", "app_url = \"https://fsdl-text-recognizer.ngrok.io/\"\n", "\n", "IFrame(app_url, width=1024, height=896)" ] }, { "cell_type": "markdown", "metadata": { "id": "BaDkEosIjcl6" }, "source": [ "## Frontend and Backend" ] }, { "cell_type": "markdown", "metadata": { "id": "cDxvNgFHgM_J" }, "source": [ "What you see above is the \"frontend\",\n", "the user-facing component, of the application.\n", "\n", "Frontend web development is typically done using\n", "JavaScript as the programming language.\n", "Most ML is done in Python (see below),\n", "so we will instead build our frontend using\n", "the Python library [**Gradio**](https://gradio.app/).\n", "\n", "> Another excellent choice for pure Python web development might be\n", "[Streamlit](https://streamlit.io/)\n", "or even, in the near future, tools built around\n", "[PyScript](https://pyscript.net/).\n", "\n", "Notice the option to \"flag\" the model's outputs.\n", "This user feedback will be sent to [**Gantry**](https://gantry.io/),\n", "where we can monitor model performance,\n", "generate alerts,\n", "and do exploratory data analysis.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ywyH6kW5uUjH" }, "source": [ "\n", "The model that reads the image to produce the text\n", "is not running\n", "in the same place as the frontend.\n", "The model is the \"backend\" of our application.\n", "We separate the two via a JSON API.\n", "The model is deployed\n", "[serverlessly](https://serverless-stack.com/chapters/what-is-serverless.html)\n", "to Amazon Web Services using\n", "[**AWS Lambda**](https://aws.amazon.com/lambda/),\n", "which runs a\n", "[**Docker**](https://docker-curriculum.com/)\n", "container that wraps up our model.\n", "\n", "> Docker is the tool of choice for virtualization/containerization.
As containerized applications become more complex,\n", "[container orchestration](https://www.vmware.com/topics/glossary/content/container-management.html)\n", "becomes important. The premier tool for orchestrating\n", "Docker containers is\n", "[Kubernetes](https://kubernetes.io/), aka k8s.\n", "Non-experts on cloud infrastructure will want to use their providers' managed service for k8s, e.g.\n", "[AWS EKS](https://aws.amazon.com/eks/)\n", "or [Google Kubernetes Engine](https://cloud.google.com/kubernetes-engine).\n", "\n", "The container image lives inside the\n", "[Elastic Container Registry](https://aws.amazon.com/ecr/),\n", "a sort of \"GitHub for Docker\" on AWS.\n", "The choice to go serverless makes it effortless to scale up our model\n", "across a number of orders of magnitude\n", "and the choice to containerize reduces friction and error\n", "when moving our model from development to production.\n", "\n", "> This could equally well be done on another cloud,\n", "like [Google Cloud Platform](https://cloud.google.com/)\n", "or [Microsoft Azure](https://azure.microsoft.com/en-us/),\n", "which offer serverless deployment via\n", "[Google Cloud Functions](https://cloud.google.com/serverless)\n", "and [Azure Functions](https://azure.microsoft.com/en-us/solutions/serverless),\n", "respectively. " ] }, { "cell_type": "markdown", "source": [ "The backend operates completely independently of the frontend,\n", "which means it can be used in multiple contexts.\n", "\n", "Run the cell below to send a query directly to the model on the backend."
], "metadata": { "id": "3XBHox87IJ8i" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "76HwEP2Vzz3F" }, "outputs": [], "source": [ "import json # JavaScript Object Notation is the lingua franca of the web\n", "\n", "from IPython.display import Image\n", "import requests # requests is the preferred library for web requests in Python\n", "\n", "lambda_url = \"https://3akxma777p53w57mmdika3sflu0fvazm.lambda-url.us-west-1.on.aws/\"\n", "image_url = \"https://fsdl-public-assets.s3-us-west-2.amazonaws.com/paragraphs/a01-077.png\"\n", "\n", "headers = {\"Content-type\": \"application/json\"} # headers ensure our request is handled correctly\n", "payload = json.dumps({\"image_url\": image_url}) # the request content is a string representation of JSON data\n", "\n", "if \"pred\" not in locals(): # a poor man's cache: if we've defined the variable pred, skip the request\n", " response = requests.post( # we POST the image to the URL, expecting a prediction as a response\n", " lambda_url, data=payload, headers=headers)\n", " pred = response.json()[\"pred\"] # the response is also json\n", "\n", "print(pred)\n", "\n", "Image(url=image_url, width=512)" ] }, { "cell_type": "markdown", "metadata": { "id": "csthw0QlgeSy" }, "source": [ "## Application Diagram" ] }, { "cell_type": "markdown", "metadata": { "id": "baYhDRKkggNk" }, "source": [ "We're only halfway through describing how the Text Recognizer works\n", "and it's already getting hard to hold the whole thing in-memory.\n", "\n", "Run the cell below to show a diagram that incorporates the entire\n", "process for creating and running the Text Recognizer,\n", "from training to feedback collection." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bsOa6gQ0YhX4" }, "outputs": [], "source": [ "diagram_url = \"https://miro.com/app/live-embed/uXjVOrOHcOg=/?moveToViewport=-1210,-1439,2575,1999\"\n", "\n", "IFrame(diagram_url, width=1024, height=512)" ] }, { "cell_type": "markdown", "metadata": { "id": "RiQgHY6Th67H" }, "source": [ "## Model Training" ] }, { "cell_type": "markdown", "metadata": { "id": "Ib6ijsumjjlm" }, "source": [ "Let's start back at the beginning -- developing a model.\n", "We'll then make our way back to where we left off above, the handoff\n", "from model development/training\n", "to the actual application.\n", "\n", "We begin by training a neural network\n", "(a [ResNet](https://pytorch.org/hub/pytorch_vision_resnet/)\n", "encoder to process the images and \n", "a [Transformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)\n", "decoder to produce the output text).\n", "\n", "Neural networks operate by applying\n", "sequences of large matrix multiplications\n", "and other array operations.\n", "These operations are much faster on GPUs than on CPUs\n", "and are relatively easy to parallelize\n", "across GPUs.\n", "This is especially true during training,\n", "where many inputs are processed in parallel,\n", "or \"batched\" together.\n", "\n", "Purchasing GPUs and properly setting up\n", "a multi-GPU machine is challenging\n", "and has high up-front costs.\n", "So we run our training via a cloud provider,\n", "specifically\n", "[**Lambda Labs GPU Cloud**](https://lambdalabs.com/service/gpu-cloud).\n", "\n", "> Other cloud providers offer GPU-accelerated compute\n", "but Lambda Labs offers it at the lowest prices,\n", "as of August 2022.\n", "Larger organizations may benefit from the extra features\n", "that integration with larger cloud providers,\n", "like AWS or GCP, can provide (e.g. 
unified authorization\n", "and control planes).\n", "Because independent, full-stack developers\n", "are often very price-sensitive, we recommend Lambda Labs --\n", "even more, we recommend checking current and historical instance prices.\n", "\n", "\n", "For smaller units of work, like debugging and quick experiments,\n", "we can use\n", "[Google Colaboratory](https://research.google.com/colaboratory/),\n", "which provides limited access to free GPU (and TPU)\n", "compute in an ephemeral environment.\n", "\n", "> For small-to-medium-sized deep learning tasks,\n", "Colab Pro (\\$10/mo.) and Colab Pro+ (\\$50/mo.)\n", "can be competitive with the larger cloud providers.\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "LUS001EIv3H7" }, "source": [ "If you're running this notebook on a machine with a GPU,\n", "e.g. on Colab, running the cell below\n", "will show some basic information on the GPU's state." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WyYVgQmlv091" }, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "id": "C-C6iWKAsZI3" }, "source": [ "Because the heavy work is done on the GPU,\n", "using lower-level libraries,\n", "we don't need to write the majority of our model development code\n", "in a performant language like C/C++ or Rust.\n", "\n", "We can instead write in a more comfortable, but slower language:\n", "it doesn't make sense to drive an F1 car to the airport\n", "for an international flight.\n", "\n", "The language of choice for deep learning is\n", "[**Python**](https://www.python.org/).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FQALjrGVwFeG" }, "outputs": [], "source": [ "import this # The Zen of Python" ] }, { "cell_type": "markdown", "metadata": { "id": "-uqfsWUQwEyl" }, "source": [ "We don't want to write our Python library for GPU acceleration from scratch,\n", "especially because we also need automatic differentiation --\n", "the 
ability to take derivatives of our neural networks.\n", "The Python/C++ library\n", "[PyTorch](https://pytorch.org/)\n", "offers GPU-accelerated array math with automatic differentiation,\n", "plus special neural network primitives and architectures.\n", "\n", "> There are two major alternatives to PyTorch\n", "for providing accelerated, differentiable array math,\n", "both from Google: early mover\n", "[TensorFlow](https://www.tensorflow.org/resources/learn-ml)\n", "and new(ish)comer\n", "[JAX](https://github.com/google/jax).\n", "The former is more common in certain larger, older enterprise settings\n", "and the latter is more common in certain bleeding-edge research settings.\n", "We choose PyTorch to split the difference,\n", "but can confidently recommend all three.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qcvCJ6b1wVRl" }, "outputs": [], "source": [ "import torch\n", "\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # run on GPU if available\n", "\n", "# create an array/tensor and track its gradients during calculations\n", "a = torch.tensor([1.], requires_grad=True) \\\n", " .to(device) # store the array data on GPU (if available)\n", "b = torch.tensor([2.]).to(device)\n", "\n", "# calculate new values, building up a \"compute graph\"\n", "c = a * b + a\n", "\n", "# compute gradient of c with respect to a by \"tracing the graph backwards\"\n", "g, = torch.autograd.grad(outputs=c, inputs=a)\n", "\n", "g" ] }, { "cell_type": "markdown", "metadata": { "id": "4zjQyN4HwUS0" }, "source": [ "\n", "PyTorch provides a number of features required for creating\n", "deep neural networks,\n", "but it doesn't include a high-level framework\n", "for training or any of a number of related engineering tasks,\n", "like metric calculation or model checkpointing.\n", "\n", "We use the\n", "[PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "library as our high-level training engineering 
framework.\n", "\n", "> PyTorch Lightning is the framework of choice\n", "for generic deep learning in PyTorch,\n", "but in natural language processing,\n", "many people instead choose libraries from\n", "[Hugging Face](https://huggingface.co/).\n", "[Keras](https://keras.io/)\n", "is the framework of choice for TensorFlow.\n", "In some ways,\n", "[Flax](https://github.com/google/flax)\n", "is the same for JAX;\n", "in others, there is not, as of July 2022, a high-level\n", "training framework in JAX." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ehvGUApGpnrV" }, "outputs": [], "source": [ "from IPython.display import YouTubeVideo\n", "\n", "lit_video_id = \"QHww1JH7IDU\"\n", "YouTubeVideo(lit_video_id, modestbranding=True, rel=False, width=512)" ] }, { "cell_type": "markdown", "metadata": { "id": "7gu9b4Ux1U-k" }, "source": [ "## Experiment and Artifact Tracking" ] }, { "cell_type": "markdown", "metadata": { "id": "vWhTbHq3sfON" }, "source": [ "ML models are challenging to debug:\n", "their inputs and outputs are often easy for humans to interpret\n", "but hard for traditional software programs to understand.\n", "\n", "They are also challenging to design:\n", "there are a number of knobs to twiddle and constants to set,\n", "like a finicky bunch of compiler flags.\n", "These are known as \"hyperparameters\".\n", "\n", "So building an ML model often looks a bit less like engineering\n", "and a bit more like experimentation.\n", "These experiments need to be tracked,\n", "as do large binary files,\n", "or artifacts,\n", "that are produced during those experiments\n", "-- like model weights.\n", "\n", "We choose\n", "[Weights & Biases](http://docs.wandb.ai)\n", "as our experiment and artifact tracking platform.\n", "\n", "> [MLflow](https://github.com/mlflow/mlflow)\n", "is an open-source library that provides similar\n", "features to W&B, but the experiment and artifact\n", "tracking server must be self-hosted,\n", "which can be burdensome 
for the already beleaguered\n", "full-stack ML developer.\n", "Basic experiment tracking can also be done\n", "using [Tensorboard](https://www.tensorflow.org/tensorboard),\n", "and shared using [tensorboard.dev](https://tensorboard.dev/),\n", "but Tensorboard does not provide artifact tracking.\n", "Artifact tracking and versioning can be done using\n", "[Git LFS](https://git-lfs.github.com/),\n", "but storage and distribution via GitHub can be expensive\n", "and it does not provide experiment tracking.\n", "[Hugging Face](https://huggingface.co/) runs an alternative\n", "git server, Hugging Face Spaces,\n", "that can display Tensorboard experiments and\n", "mandates Git LFS for large files (where large means >10MB).\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "rL1uL-SewukM" }, "source": [ "The resulting experiment logs can be made very rich\n", "and are invaluable for debugging\n", "(e.g. tracking bugs through the git history)\n", "and communicating results inside and across teams.\n", "\n", "Run the cell below to display the logs from an experiment\n", "that was run while designing and debugging the Text Recognizer model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Uw4LUYkgwvb0" }, "outputs": [], "source": [ "experiment_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/runs/lfjjnxw8\"\n", "\n", "IFrame(experiment_url, width=1024, height=768)" ] }, { "cell_type": "markdown", "metadata": { "id": "O5WHO_CTwrgf" }, "source": [ "Logged _data_ is inert.\n", "It becomes usable, actionable _information_\n", "when it is given context and form.\n", "\n", "Run the cell below to take a look at a dashboard,\n", "built inside W&B,\n", "reporting the results of a training run." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "W0V-qbh8uWwb" }, "outputs": [], "source": [ "dashboard_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/reports/Training-Run-2022-06-02--VmlldzoyMTAyOTkw\"\n", "\n", "IFrame(dashboard_url, width=1024, height=768)" ] }, { "cell_type": "markdown", "metadata": { "id": "R3I60PY61IXH" }, "source": [ "## The Handoff to Production" ] }, { "cell_type": "markdown", "metadata": { "id": "IsT1P1UG1hXW" }, "source": [ "PyTorch Lightning produces large artifacts called \"checkpoints\"\n", "that can be used to restart model training when it stops or is interrupted\n", "(which allows the use of much cheaper\n", "[\"preemptible\" cloud instances](https://www.determined.ai/blog/scale-your-model-development-on-a-budget)).\n", "\n", "We store these artifacts on Weights & Biases.\n", "\n", "When they are ready to be deployed to production,\n", "we compile these model checkpoints down to a dialect of Torch called\n", "[TorchScript](https://pytorch.org/docs/stable/jit.html)\n", "that is more portable:\n", "it drops the training engineering code\n", "and produces an artifact that is executable in C++ or in Python.\n", "We stick with a Python environment for simplicity.\n", "\n", "> TensorFlow has similar facilities\n", "for delivering models, including\n", "[TensorFlow.js](https://www.tensorflow.org/js)\n", "and [TensorFlow Extended (TFX)](https://www.tensorflow.org/tfx).\n", "There are also a number of alternative portable runtime environments\n", "for ML models, including\n", "[ONNX Runtime](https://onnx.ai/).\n", "\n", "These deployable models are also stored on Weights & Biases,\n", "which connects them to rich metadata,\n", "including the experiments and training runs\n", "that produced the checkpoints from which they were derived.\n", "\n", "Run the cell below to review the metadata for a deployable\n", "version of the Text Recognizer model."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AB8OYJ423Qvy" }, "outputs": [], "source": [ "artifact_url = \"https://wandb.ai/cfrye59/fsdl-text-recognizer-2021-training/artifacts/prod-ready/paragraph-text-recognizer/v8\"\n", "\n", "IFrame(artifact_url, width=1024, height=768)" ] }, { "cell_type": "markdown", "metadata": { "id": "4O6VGIqM3ugW" }, "source": [ "We can pull this file down,\n", "package it into a Docker container\n", "via a small Python script,\n", "and ship it off to a container registry, like AWS ECR or Docker Hub, so that\n", "it can be used to provide the backend to our application." ] }, { "cell_type": "markdown", "metadata": { "id": "CS-1UA1s1hsl" }, "source": [ "## Application Diagram, Redux" ] }, { "cell_type": "markdown", "metadata": { "id": "eqr39GhA313z" }, "source": [ "Now that we have made it through the\n", "🥞 full stack 🥞 of the Text Recognizer application,\n", "let's take a look at the application diagram again." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SH1acJ8f1kpP" }, "outputs": [], "source": [ "IFrame(diagram_url, width=1024, height=512)" ] }, { "cell_type": "markdown", "source": [ "Over the remainder of the labs,\n", "we will put all of these pieces together,\n", "learning more about the problems they solve,\n", "the tradeoffs they make,\n", "and how they are best used." 
], "metadata": { "id": "j_Yy4d3Dpi3o" } } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "Lab 00 - Overview.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3.7.13 ('fsdl-text-recognizer-2022')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "vscode": { "interpreter": { "hash": "0f056848cf5d2396a4970b625f23716aa539c2ff5334414c1b5d98d7daae66f6" } } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: pyproject.toml ================================================ [tool.flake8] # configured in .flake8 [tool.darglint] # configured in .flake8 [tool.black] line-length = 120 target-version = ["py37"] [tool.mypy] ignore_missing_imports = true exclude = ["training/logs"] [tool.pytest.ini_options] markers = [ "slow: marks a test as slow (deselect with '-m \"not slow\"')", "data: marks a test as dependent on a data download (deselect with '-m \"not data\"')" ] addopts = "--cov training --cov text_recognizer --cov-branch --doctest-modules --ignore training/logs -m 'not data' --ignore-glob **/bootstrap.py" ================================================ FILE: readme.md ================================================ # 🥞 Full Stack Deep Learning Fall 2022 Labs Welcome! As part of Full Stack Deep Learning 2022, we will incrementally develop a complete deep learning codebase to create and deploy a model that understands the content of hand-written paragraphs. For an overview of the Text Recognizer application architecture, click the badge below to open an interactive Jupyter notebook on Google Colab:


We will use the modern stack of [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://www.pytorchlightning.ai/). We will use the main workhorses of DL today: CNNs and Transformers. We will manage our experiments using what we believe to be the best tool for the job: [Weights & Biases](https://docs.wandb.ai/). We will set up a quality assurance and continuous integration system for our codebase using [pre-commit](https://pre-commit.com/) and [GitHub Actions](https://docs.github.com/en/actions). We will package up the prediction system and deploy it as a [Docker](https://docs.docker.com/) container on [AWS Lambda](https://aws.amazon.com/lambda/). We will wrap that prediction system in a frontend written in Python using [Gradio](https://gradio.app/docs). We will set up monitoring that alerts us to potential issues in our model using [Gantry](https://gantry.io/). ## Click the badges below to access individual lab notebooks on Colab and videos on YouTube | Lab | Colab | Video | | :-- | :---: | :---: | | **Lab Overview** | [![open-in-colab]](https://fsdl.me/lab00-colab) | [![yt-logo]](https://fsdl.me/2022-lab-overview-video) | | **Lab 01: Deep Neural Networks in PyTorch** | [![open-in-colab]](https://fsdl.me/lab01-colab) | [![yt-logo]](https://fsdl.me/2022-lab-01-video) | | **Lab 02a: PyTorch Lightning** | [![open-in-colab]](https://fsdl.me/lab02a-colab) | [![yt-logo]](https://fsdl.me/2022-lab-02-video) | | **Lab 02b: Training a CNN on Synthetic Handwriting Data** | [![open-in-colab]](https://fsdl.me/lab02b-colab) | [![yt-logo]](https://fsdl.me/2022-lab-02-video) | | **Lab 03: Transformers and Paragraphs** | [![open-in-colab]](https://fsdl.me/lab03-colab) | [![yt-logo]](https://fsdl.me/2022-lab-03-video) | | **Lab 04: Experiment Tracking** | [![open-in-colab]](https://fsdl.me/lab04-colab) | [![yt-logo]](https://fsdl.me/2022-lab-04-video) | | **Lab 05: Troubleshooting & Testing** | [![open-in-colab]](https://fsdl.me/lab05-colab) | 
[![yt-logo]](https://fsdl.me/2022-lab-05-video) | | **Lab 06: Data Annotation** | [![open-in-colab]](https://fsdl.me/lab06-colab) | [![yt-logo]](https://fsdl.me/2022-lab-06-video) | | **Lab 07: Deployment** | [![open-in-colab]](https://fsdl.me/lab07-colab) | [![yt-logo]](https://fsdl.me/2022-lab-07-video) | | **Lab 08: Monitoring** | [![open-in-colab]](https://fsdl.me/lab08-colab) | [![yt-logo]](https://fsdl.me/2022-lab-08-video) | [yt-logo]: https://fsdl.me/yt-logo-badge [open-in-colab]: https://colab.research.google.com/assets/colab-badge.svg ================================================ FILE: requirements/dev-lint.in ================================================ -c prod.txt -c dev.txt bandit==1.7.4 black==22.3.0 darglint==1.8.1 flake8<4 flake8-bandit==3.0.0 flake8-bugbear==22.4.25 flake8-docstrings==1.6.0 flake8-import-order==0.18.1 mypy==0.960 # mypy version also pinned in .pre-commit-config.yaml safety==1.10.3 shellcheck-py==0.8.0.4 types-toml==0.10.7 ================================================ FILE: requirements/dev.in ================================================ -c prod.txt boltons coverage[toml] defusedxml itermplot ipywidgets matplotlib nltk pre-commit pytest pytest-cov scipy toml # versioned to give pip hints coverage[toml]==6.4 pytest==7.1.1 pytest-cov==3.0.0 # versioned to match Google Colab notebook>=6.5,<6.6 nbconvert>=6.5,<6.6 seaborn>=0.13,<0.14 tornado>=6.3,<6.4 # versioned to improve stability pytorch-lightning==1.6.3 torchmetrics<0.8 wandb==0.12.17 ================================================ FILE: requirements/dev.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile requirements/dev.in # absl-py==1.4.0 # via tensorboard aiohttp==3.8.5 # via # -c requirements/prod.txt # fsspec aiosignal==1.3.1 # via # -c requirements/prod.txt # aiohttp anyio==3.7.1 # via # -c requirements/prod.txt # jupyter-server argon2-cffi==23.1.0 # via 
# jupyter-server # nbclassic # notebook argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.2.3 # via isoduration asttokens==2.2.1 # via stack-data async-timeout==4.0.3 # via # -c requirements/prod.txt # aiohttp attrs==23.1.0 # via # -c requirements/prod.txt # aiohttp # jsonschema # pytest # referencing backcall==0.2.0 # via ipython beautifulsoup4==4.12.2 # via nbconvert bleach==6.0.0 # via nbconvert boltons==23.0.0 # via -r requirements/dev.in cachetools==4.2.4 # via # -c requirements/prod.txt # google-auth certifi==2023.7.22 # via # -c requirements/prod.txt # requests # sentry-sdk cffi==1.15.1 # via argon2-cffi-bindings cfgv==3.4.0 # via pre-commit charset-normalizer==3.2.0 # via # -c requirements/prod.txt # aiohttp # requests click==8.1.7 # via # -c requirements/prod.txt # nltk # wandb comm==0.1.4 # via # ipykernel # ipywidgets contourpy==1.1.0 # via # -c requirements/prod.txt # matplotlib coverage[toml]==6.4 # via # -r requirements/dev.in # pytest-cov cycler==0.11.0 # via # -c requirements/prod.txt # matplotlib debugpy==1.6.7.post1 # via ipykernel decorator==5.1.1 # via ipython defusedxml==0.7.1 # via # -r requirements/dev.in # nbconvert distlib==0.3.7 # via virtualenv docker-pycreds==0.4.0 # via wandb entrypoints==0.4 # via # jupyter-client # nbconvert exceptiongroup==1.1.3 # via # -c requirements/prod.txt # anyio executing==1.2.0 # via stack-data fastjsonschema==2.18.0 # via nbformat filelock==3.12.2 # via # -c requirements/prod.txt # torch # triton # virtualenv fonttools==4.42.1 # via # -c requirements/prod.txt # matplotlib fqdn==1.5.1 # via jsonschema frozenlist==1.4.0 # via # -c requirements/prod.txt # aiohttp # aiosignal fsspec[http]==2023.6.0 # via # -c requirements/prod.txt # pytorch-lightning # torch gitdb==4.0.10 # via gitpython gitpython==3.1.32 # via wandb google-auth==2.22.0 # via # google-auth-oauthlib # tensorboard google-auth-oauthlib==1.0.0 # via tensorboard grpcio==1.57.0 # via tensorboard identify==2.5.27 # via pre-commit idna==3.4 # via 
# -c requirements/prod.txt # anyio # jsonschema # requests # yarl iniconfig==2.0.0 # via pytest ipykernel==6.25.1 # via # nbclassic # notebook ipython==8.14.0 # via # ipykernel # ipywidgets ipython-genutils==0.2.0 # via # nbclassic # notebook ipywidgets==8.1.0 # via -r requirements/dev.in isoduration==20.11.0 # via jsonschema itermplot==0.331 # via -r requirements/dev.in jedi==0.19.0 # via ipython jinja2==3.1.2 # via # -c requirements/prod.txt # jupyter-server # nbclassic # nbconvert # notebook # torch joblib==1.3.2 # via nltk jsonpointer==2.4 # via jsonschema jsonschema[format-nongpl]==4.19.0 # via # -c requirements/prod.txt # jupyter-events # nbformat jsonschema-specifications==2023.7.1 # via # -c requirements/prod.txt # jsonschema jupyter-client==7.4.9 # via # ipykernel # jupyter-server # nbclassic # nbclient # notebook jupyter-core==5.3.1 # via # ipykernel # jupyter-client # jupyter-server # nbclassic # nbclient # nbconvert # nbformat # notebook jupyter-events==0.7.0 # via jupyter-server jupyter-server==2.7.2 # via # nbclassic # notebook-shim jupyter-server-terminals==0.4.4 # via jupyter-server jupyterlab-pygments==0.2.2 # via nbconvert jupyterlab-widgets==3.0.8 # via ipywidgets kiwisolver==1.4.5 # via # -c requirements/prod.txt # matplotlib lxml==4.9.3 # via nbconvert markdown==3.4.4 # via tensorboard markupsafe==2.1.3 # via # -c requirements/prod.txt # jinja2 # nbconvert # werkzeug matplotlib==3.7.2 # via # -c requirements/prod.txt # -r requirements/dev.in # itermplot # seaborn matplotlib-inline==0.1.6 # via # ipykernel # ipython mistune==0.8.4 # via nbconvert mpmath==1.3.0 # via # -c requirements/prod.txt # sympy multidict==6.0.4 # via # -c requirements/prod.txt # aiohttp # yarl nbclassic==1.0.0 # via notebook nbclient==0.8.0 # via nbconvert nbconvert==6.5.4 # via # -r requirements/dev.in # jupyter-server # nbclassic # notebook nbformat==5.9.2 # via # jupyter-server # nbclassic # nbclient # nbconvert # notebook nest-asyncio==1.5.7 # via # ipykernel # 
jupyter-client # nbclassic # notebook networkx==3.1 # via # -c requirements/prod.txt # torch nltk==3.8.1 # via -r requirements/dev.in nodeenv==1.8.0 # via pre-commit notebook==6.5.5 # via -r requirements/dev.in notebook-shim==0.2.3 # via nbclassic numpy==1.25.2 # via # -c requirements/prod.txt # contourpy # itermplot # matplotlib # pandas # pytorch-lightning # scipy # seaborn # tensorboard # torchmetrics nvidia-cublas-cu12==12.1.3.1 # via # -c requirements/prod.txt # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch nvidia-cuda-cupti-cu12==12.1.105 # via # -c requirements/prod.txt # torch nvidia-cuda-nvrtc-cu12==12.1.105 # via # -c requirements/prod.txt # torch nvidia-cuda-runtime-cu12==12.1.105 # via # -c requirements/prod.txt # torch nvidia-cudnn-cu12==8.9.2.26 # via # -c requirements/prod.txt # torch nvidia-cufft-cu12==11.0.2.54 # via # -c requirements/prod.txt # torch nvidia-curand-cu12==10.3.2.106 # via # -c requirements/prod.txt # torch nvidia-cusolver-cu12==11.4.5.107 # via # -c requirements/prod.txt # torch nvidia-cusparse-cu12==12.1.0.106 # via # -c requirements/prod.txt # nvidia-cusolver-cu12 # torch nvidia-nccl-cu12==2.18.1 # via # -c requirements/prod.txt # torch nvidia-nvjitlink-cu12==12.3.101 # via # -c requirements/prod.txt # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 nvidia-nvtx-cu12==12.1.105 # via # -c requirements/prod.txt # torch oauthlib==3.2.2 # via requests-oauthlib overrides==7.4.0 # via jupyter-server packaging==23.1 # via # -c requirements/prod.txt # ipykernel # jupyter-server # matplotlib # nbconvert # pytest # pytorch-lightning # torchmetrics pandas==2.0.3 # via # -c requirements/prod.txt # seaborn pandocfilters==1.5.0 # via nbconvert parso==0.8.3 # via jedi pathtools==0.1.2 # via wandb pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython pillow==9.4.0 # via # -c requirements/prod.txt # matplotlib platformdirs==3.10.0 # via # jupyter-core # virtualenv pluggy==1.2.0 # via pytest pre-commit==3.3.3 # via -r requirements/dev.in 
prometheus-client==0.17.1 # via # jupyter-server # nbclassic # notebook promise==2.3 # via wandb prompt-toolkit==3.0.39 # via ipython protobuf==3.20.3 # via # tensorboard # wandb psutil==5.9.5 # via # ipykernel # wandb ptyprocess==0.7.0 # via # pexpect # terminado pure-eval==0.2.2 # via stack-data py==1.11.0 # via pytest pyasn1==0.5.0 # via # pyasn1-modules # rsa pyasn1-modules==0.3.0 # via google-auth pycparser==2.21 # via cffi pydeprecate==0.3.2 # via # pytorch-lightning # torchmetrics pygments==2.16.1 # via # ipython # nbconvert pyparsing==3.0.9 # via # -c requirements/prod.txt # matplotlib pytest==7.1.1 # via # -r requirements/dev.in # pytest-cov pytest-cov==3.0.0 # via -r requirements/dev.in python-dateutil==2.8.2 # via # -c requirements/prod.txt # jupyter-client # matplotlib # pandas # wandb python-json-logger==2.0.7 # via jupyter-events pytorch-lightning==1.6.3 # via -r requirements/dev.in pytz==2023.3 # via # -c requirements/prod.txt # pandas pyyaml==6.0.1 # via # -c requirements/prod.txt # jupyter-events # pre-commit # pytorch-lightning # wandb pyzmq==24.0.1 # via # ipykernel # jupyter-client # jupyter-server # nbclassic # notebook referencing==0.30.2 # via # -c requirements/prod.txt # jsonschema # jsonschema-specifications # jupyter-events regex==2023.8.8 # via # -c requirements/prod.txt # nltk requests==2.31.0 # via # -c requirements/prod.txt # fsspec # requests-oauthlib # tensorboard # wandb requests-oauthlib==1.3.1 # via google-auth-oauthlib rfc3339-validator==0.1.4 # via # jsonschema # jupyter-events rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events rpds-py==0.9.2 # via # -c requirements/prod.txt # jsonschema # referencing rsa==4.9 # via google-auth scipy==1.11.2 # via -r requirements/dev.in seaborn==0.13.1 # via -r requirements/dev.in send2trash==1.8.2 # via # jupyter-server # nbclassic # notebook sentry-sdk==1.29.2 # via wandb setproctitle==1.3.2 # via wandb shortuuid==1.0.11 # via wandb six==1.16.0 # via # -c requirements/prod.txt # 
asttokens # bleach # docker-pycreds # google-auth # itermplot # promise # python-dateutil # rfc3339-validator # wandb smmap==5.0.0 # via gitdb sniffio==1.3.0 # via # -c requirements/prod.txt # anyio soupsieve==2.4.1 # via beautifulsoup4 stack-data==0.6.2 # via ipython sympy==1.12 # via # -c requirements/prod.txt # torch tensorboard==2.14.0 # via pytorch-lightning tensorboard-data-server==0.7.1 # via tensorboard terminado==0.17.1 # via # jupyter-server # jupyter-server-terminals # nbclassic # notebook tinycss2==1.2.1 # via nbconvert toml==0.10.2 # via -r requirements/dev.in tomli==2.0.1 # via # coverage # pytest torch==2.1.1 # via # -c requirements/prod.txt # pytorch-lightning # torchmetrics torchmetrics==0.7.3 # via # -r requirements/dev.in # pytorch-lightning tornado==6.3.3 # via # -r requirements/dev.in # ipykernel # jupyter-client # jupyter-server # nbclassic # notebook # terminado tqdm==4.66.1 # via # -c requirements/prod.txt # nltk # pytorch-lightning traitlets==5.9.0 # via # comm # ipykernel # ipython # ipywidgets # jupyter-client # jupyter-core # jupyter-events # jupyter-server # matplotlib-inline # nbclassic # nbclient # nbconvert # nbformat # notebook triton==2.1.0 # via # -c requirements/prod.txt # torch typing-extensions==4.7.1 # via # -c requirements/prod.txt # pytorch-lightning # torch tzdata==2023.3 # via # -c requirements/prod.txt # pandas uri-template==1.3.0 # via jsonschema urllib3==1.26.16 # via # -c requirements/prod.txt # google-auth # requests # sentry-sdk virtualenv==20.24.3 # via pre-commit wandb==0.12.17 # via -r requirements/dev.in wcwidth==0.2.6 # via prompt-toolkit webcolors==1.13 # via jsonschema webencodings==0.5.1 # via # bleach # tinycss2 websocket-client==1.6.2 # via jupyter-server werkzeug==2.3.7 # via tensorboard wheel==0.41.2 # via tensorboard widgetsnbextension==4.0.8 # via ipywidgets yarl==1.9.2 # via # -c requirements/prod.txt # aiohttp # The following packages are considered to be unsafe in a requirements file: # setuptools 
================================================ FILE: requirements/prod.in ================================================ h5py importlib-metadata>=4.4 numpy pyngrok>=6.0,<6.1 requests smart_open[s3] tqdm # versioned for stability gantry==0.4.9 gradio==3.40.1 # versioned to match Google Colab up to minor Jinja2>=3.1,<3.2 pillow>=9.4,<9.5 torch>=2.1,<2.2 torchvision>=0.16,<0.17 ================================================ FILE: requirements/prod.txt ================================================ # # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile requirements/prod.in # aiofiles==23.2.1 # via gradio aiohttp==3.8.5 # via gradio aiosignal==1.3.1 # via aiohttp altair==5.0.1 # via gradio annotated-types==0.5.0 # via pydantic anyio==3.7.1 # via # httpcore # starlette async-timeout==4.0.3 # via aiohttp attrs==23.1.0 # via # aiohttp # jsonschema # referencing boto3==1.28.34 # via # boto3-extensions # smart-open boto3-extensions==0.20.0 # via gantry botocore==1.31.34 # via # boto3 # boto3-extensions # s3transfer cachetools==4.2.4 # via gantry certifi==2023.7.22 # via # httpcore # httpx # requests charset-normalizer==3.2.0 # via # aiohttp # requests click==8.1.7 # via # gantry # uvicorn click-spinner==0.1.10 # via gantry colorama==0.4.6 # via # gantry # halo # log-symbols contourpy==1.1.0 # via matplotlib cycler==0.11.0 # via matplotlib dateparser==1.1.8 # via gantry exceptiongroup==1.1.3 # via anyio fastapi==0.101.1 # via gradio ffmpy==0.3.1 # via gradio filelock==3.12.2 # via # huggingface-hub # torch # triton fonttools==4.42.1 # via matplotlib frozenlist==1.4.0 # via # aiohttp # aiosignal fsspec==2023.6.0 # via # gradio-client # huggingface-hub # torch gantry==0.4.9 # via -r requirements/prod.in gradio==3.40.1 # via -r requirements/prod.in gradio-client==0.5.0 # via gradio h11==0.14.0 # via # httpcore # uvicorn h5py==3.9.0 # via -r requirements/prod.in halo==0.0.31 # via gantry httpcore==0.17.3 # via httpx 
httpx==0.24.1 # via # gradio # gradio-client huggingface-hub==0.16.4 # via # gradio # gradio-client idna==3.4 # via # anyio # httpx # requests # yarl importlib-metadata==6.8.0 # via -r requirements/prod.in importlib-resources==6.0.1 # via gradio isodate==0.6.1 # via gantry jinja2==3.1.2 # via # -r requirements/prod.in # altair # gradio # torch jmespath==1.0.1 # via # boto3 # botocore jsonschema==4.19.0 # via altair jsonschema-specifications==2023.7.1 # via jsonschema kiwisolver==1.4.5 # via matplotlib linkify-it-py==2.0.2 # via markdown-it-py log-symbols==0.0.14 # via halo markdown-it-py[linkify]==2.2.0 # via # gradio # mdit-py-plugins markupsafe==2.1.3 # via # gradio # jinja2 marshmallow==3.20.1 # via # gantry # marshmallow-oneofschema marshmallow-oneofschema==3.0.1 # via gantry matplotlib==3.7.2 # via gradio mdit-py-plugins==0.3.3 # via gradio mdurl==0.1.2 # via markdown-it-py monotonic==1.6 # via gantry mpmath==1.3.0 # via sympy multidict==6.0.4 # via # aiohttp # yarl networkx==3.1 # via torch numpy==1.25.2 # via # -r requirements/prod.in # altair # contourpy # gantry # gradio # h5py # matplotlib # pandas # torchvision nvidia-cublas-cu12==12.1.3.1 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch nvidia-cuda-cupti-cu12==12.1.105 # via torch nvidia-cuda-nvrtc-cu12==12.1.105 # via torch nvidia-cuda-runtime-cu12==12.1.105 # via torch nvidia-cudnn-cu12==8.9.2.26 # via torch nvidia-cufft-cu12==11.0.2.54 # via torch nvidia-curand-cu12==10.3.2.106 # via torch nvidia-cusolver-cu12==11.4.5.107 # via torch nvidia-cusparse-cu12==12.1.0.106 # via # nvidia-cusolver-cu12 # torch nvidia-nccl-cu12==2.18.1 # via torch nvidia-nvjitlink-cu12==12.3.101 # via # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 nvidia-nvtx-cu12==12.1.105 # via torch orjson==3.9.5 # via gradio packaging==23.1 # via # gradio # gradio-client # huggingface-hub # marshmallow # matplotlib pandas==2.0.3 # via # altair # gantry # gradio pillow==9.4.0 # via # -r requirements/prod.in # gradio # matplotlib # 
torchvision pydantic==2.3.0 # via # fastapi # gradio pydantic-core==2.6.3 # via pydantic pydub==0.25.1 # via gradio pyngrok==6.0.0 # via -r requirements/prod.in pyparsing==3.0.9 # via matplotlib python-dateutil==2.8.2 # via # botocore # dateparser # gantry # matplotlib # pandas python-multipart==0.0.6 # via gradio pytz==2023.3 # via # dateparser # pandas pyyaml==6.0.1 # via # gantry # gradio # huggingface-hub # pyngrok referencing==0.30.2 # via # jsonschema # jsonschema-specifications regex==2023.8.8 # via dateparser requests==2.31.0 # via # -r requirements/prod.in # gantry # gradio # gradio-client # huggingface-hub # torchvision rpds-py==0.9.2 # via # jsonschema # referencing s3transfer==0.6.2 # via boto3 semantic-version==2.10.0 # via gradio six==1.16.0 # via # halo # isodate # python-dateutil smart-open[s3]==6.3.0 # via -r requirements/prod.in sniffio==1.3.0 # via # anyio # httpcore # httpx spinners==0.0.24 # via halo starlette==0.27.0 # via fastapi sympy==1.12 # via torch tabulate==0.9.0 # via gantry termcolor==2.3.0 # via halo toolz==0.12.0 # via altair torch==2.1.1 # via # -r requirements/prod.in # torchvision torchvision==0.16.1 # via -r requirements/prod.in tqdm==4.66.1 # via # -r requirements/prod.in # gantry # huggingface-hub triton==2.1.0 # via torch typeguard==2.13.3 # via gantry typing-extensions==4.7.1 # via # altair # fastapi # gantry # gradio # gradio-client # huggingface-hub # pydantic # pydantic-core # torch # uvicorn tzdata==2023.3 # via pandas tzlocal==5.0.1 # via dateparser uc-micro-py==1.0.2 # via linkify-it-py urllib3==1.26.16 # via # botocore # requests uvicorn==0.23.2 # via gradio websockets==11.0.3 # via # gradio # gradio-client yarl==1.9.2 # via aiohttp zipp==3.16.2 # via importlib-metadata ================================================ FILE: setup/readme.md ================================================ # Setup Deep learning requires access to accelerated computation hardware. Most commonly, those are NVIDIA GPUs or Google TPUs. 
If you have access to a computer that has an NVIDIA GPU and runs Linux, you're welcome to [set it up](#Local) for local use.

If you don't, you can get free compute with [Google Colab](#Colab).

## Colab

Google Colab is a great way to get access to fast GPUs for free. All you need is a Google account.

The preferred way to interact with the labs on Colab is just to click on the Colab badges that accompany them. For example, you can [get started with the labs on Colab here](https://fsdl.me/lab00-colab).
All setup is handled automatically, so you can immediately start working on the labs.

But if you just want to use the codebase, then go to [https://colab.research.google.com](https://colab.research.google.com) and create a new notebook.

Connect your new notebook to a GPU runtime by doing Runtime > Change Runtime type > GPU.

![](colab_runtime.png)

Now, run `!nvidia-smi` in the first cell (press Shift+Enter to run a cell). You should see a table showing your precious GPU :)

Now, paste the following into a cell and run it:

```py
# FSDL 2022 Setup
lab_idx = None

if "bootstrap" not in locals() or bootstrap.run:
    # path management for Python
    pythonpath, = !echo $PYTHONPATH
    if "." not in pythonpath.split(":"):
        pythonpath = ".:" + pythonpath
        %env PYTHONPATH={pythonpath}
        !echo $PYTHONPATH

    # get both Colab and local notebooks into the same state
    !wget --quiet https://fsdl.me/gist-bootstrap -O bootstrap.py
    import bootstrap

    # change into the lab directory
    bootstrap.change_to_lab_dir(lab_idx=lab_idx)

    # allow "hot-reloading" of modules
    %load_ext autoreload
    %autoreload 2

    bootstrap.run = False  # change to True to re-run setup

!pwd
%ls
```

The bootstrap script will check out our lab repository, `cd` into it, and install required packages. It also allows Python to find packages in the current working directory.

From there, you can `%cd` into a lab folder to play around with the codebase for that lab, either by directly writing Python, e.g. `import text_recognizer`, or by running shell commands, like `!python training/run_experiment.py`.

### Colab Pro

You may be interested in signing up for [Colab Pro](https://colab.research.google.com/signup).

For $10/month, you get priority access to faster GPUs (e.g. [P100 vs K80](https://www.xcelerit.com/computing-benchmarks/insights/nvidia-p100-vs-k80-gpu/)) and TPUs, a 24h rather than 12h maximum runtime, and more RAM.

## Local

Setting up a machine that you can sit in front of or SSH into is easy.
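
Before starting, it can help to confirm that the machine meets the two conditions mentioned at the top of this document: it runs Linux and it has an NVIDIA GPU. The following is a minimal sketch, not part of the lab codebase. It assumes that `nvidia-smi` is on your `PATH` whenever the NVIDIA driver is installed, and the function names `describe_gpus` and `local_setup_looks_feasible` are hypothetical names chosen for illustration.

```python
import platform
import shutil
import subprocess


def describe_gpus():
    """Return one 'name, total memory' string per GPU, or [] if none can be queried."""
    # Assumption: a working NVIDIA driver install puts nvidia-smi on the PATH.
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,  # a failing driver should yield [], not an exception
    )
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


def local_setup_looks_feasible():
    """Rough check: Linux plus at least one GPU visible to nvidia-smi."""
    return platform.system() == "Linux" and bool(describe_gpus())


if __name__ == "__main__":
    print("GPUs found:", describe_gpus())
    print("Local setup looks feasible:", local_setup_looks_feasible())
```

If this reports that local setup does not look feasible, Colab is likely the easier route.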
### Watch a walkthrough video [here](https://fsdl.me/2022-local-setup-video).

If you get stuck, it's better to at least [get started with the labs on Colab](https://fsdl.me/lab00-colab), where setup is just a single click, rather than getting frustrated and burning out on annoying environment management, networking, and systems administration issues that aren't as relevant to making ML-powered products.

### Summary

- `environment.yml` specifies Python and optionally CUDA/CUDNN
- `make conda-update` creates/updates a virtual environment
- `conda activate fsdl-text-recognizer-2022` activates the virtual environment
- `requirements/prod.in` and `requirements/dev.in` specify core Python packages in that environment
- `make pip-tools` resolves all other Python dependencies and installs them
- `export PYTHONPATH=.:$PYTHONPATH` makes the current directory visible on your Python path -- add it to your `~/.bashrc` and `source ~/.bashrc`

### 1. Check out the repo

```
git clone https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022-labs.git
cd fsdl-text-recognizer-2022-labs
```

### 2. Set up the Python environment

We use [`conda`](https://docs.conda.io/en/latest/miniconda.html) for managing Python and CUDA versions, and [`pip-tools`](https://github.com/jazzband/pip-tools) for managing Python package dependencies.

We also provide a `Makefile` to make setup dead-simple.

#### First: Install the Python + CUDA environment using Anaconda

Conda is an open-source package management system and environment management system that runs on Windows, macOS, and Linux. It is most closely associated with Python, but [in fact it can manage more than just Python environments](https://jakevdp.github.io/blog/2016/08/25/conda-myths-and-misconceptions/).

To install `conda`, follow the instructions at https://conda.io/projects/conda/en/latest/user-guide/install/linux.html.
Conda will install the appropriate version of Python for you in the project environment, so it doesn't matter which installer you choose. In this project we use the version of Python used in Google Colab, which at the time of writing is Python 3.10.

Note that you will likely need to close and re-open your terminal.
Afterwards, you should be able to run the `conda` command in your terminal.

Run `make conda-update` to create an environment called `fsdl-text-recognizer-2022`, as defined in `environment.yml`.
This environment will provide us with the right Python version as well as the CUDA and CUDNN libraries.

If you edit `environment.yml`, just run `make conda-update` again to get the latest changes.

Next, activate the conda environment.

```sh
conda activate fsdl-text-recognizer-2022
```

**IMPORTANT**: every time you work in this directory, make sure to start your session with `conda activate fsdl-text-recognizer-2022`.

#### Next: install Python packages

Next, install all necessary Python packages by running `make pip-tools`.

Using `pip-tools` lets us do three nice things:

1. Separate out dev from production dependencies (`dev.in` vs `prod.in`).
2. Have a lockfile of exact versions for all dependencies (the auto-generated `dev.txt` and `prod.txt`).
3. Allow us to easily deploy to targets that don't support the `conda` environment, like Colab.

#### Set PYTHONPATH

Last, run `export PYTHONPATH=.:$PYTHONPATH` before executing any commands later on, or you will get errors like this:

```python
ModuleNotFoundError: No module named 'text_recognizer'
```

In order to not have to set `PYTHONPATH` in every terminal you open, just add that line as the last line of the `~/.bashrc` file using a text editor of your choice (e.g. `nano ~/.bashrc`) or by appending it with `>>`. Use single quotes so that `$PYTHONPATH` is written to the file literally and expanded each time a shell starts, rather than once when you run `echo`:

```bash
echo 'export PYTHONPATH=.:$PYTHONPATH' >> ~/.bashrc
```
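
To confirm that the `PYTHONPATH` change took effect, a small check along the following lines may help. This is a sketch rather than part of the repository: the helper names `on_pythonpath` and `is_importable` are hypothetical, and it assumes you run it from inside a lab directory that contains the `text_recognizer` package.

```python
import importlib.util
import os


def on_pythonpath(entry=".", env=None):
    """Return True if `entry` is one of the entries listed in PYTHONPATH."""
    env = os.environ if env is None else env
    # PYTHONPATH entries are separated by os.pathsep (":" on Linux and macOS).
    return entry in env.get("PYTHONPATH", "").split(os.pathsep)


def is_importable(module_name):
    """Return True if Python can locate a top-level module, without importing it."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("'.' on PYTHONPATH:", on_pythonpath())
    print("text_recognizer importable:", is_importable("text_recognizer"))
```

If the second line prints `False` while the first prints `True`, the most likely cause is that the current working directory is not a lab folder.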