Repository: deep-learning-indaba/indaba-pracs-2022
Branch: main
Commit: 16f8ccfbc141
Files: 10
Total size: 4.3 MB
Directory structure:
gitextract_dumor4er/
├── LICENSE
├── README.MD
└── practicals/
├── Bayesian_Deep_Learning_Prac.ipynb
├── GNN_practical.ipynb
├── Indaba_2022_Prac_Template.ipynb
├── Introduction_to_ML_using_JAX.ipynb
├── array_algebra.ipynb
├── attention_and_transformers.ipynb
├── deep_generative_models.ipynb
└── introduction_to_reinforcement_learning.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2022, Deep Learning Indaba
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.MD
================================================
# Deep Learning Indaba Practicals 2022
## The Practicals
| Topic 💥 | Description 📘 |
|:--- |----------------------------------------------------------|
[Introduction to ML using JAX](https://github.com/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/Introduction_to_ML_using_JAX.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/Introduction_to_ML_using_JAX.ipynb) | In this tutorial, we will learn about JAX, a new machine learning framework that has taken deep learning research by storm! JAX is praised for its speed, and we will learn how to achieve these speedups using core concepts in JAX, such as automatic differentiation (`grad`), parallelization (`pmap`), vectorization (`vmap`), just-in-time compilation (`jit`), and more. We will then use what we have learned to implement Linear Regression effectively while learning some of the fundamentals of optimization. |
[Bayesian Deep Learning](https://github.com/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/Bayesian_Deep_Learning_Prac.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/Bayesian_Deep_Learning_Prac.ipynb) | Bayesian inference provides us with the tools to update our beliefs consistently when we observe data. Compared to the more common loss-minimisation approach to learning, Bayesian methods offer us calibrated uncertainty estimates, resistance to overfitting, and even approaches to select hyper-parameters without a validation set. 🚀In this prac we will learn to do all of these things!🚀 |
[Transformers and Attention](https://github.com/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/attention_and_transformers.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/attention_and_transformers.ipynb) | The transformer architecture, introduced in Vaswani et al.'s 2017 paper [Attention is All You Need](https://arxiv.org/abs/1706.03762?amp=1), has significantly impacted the deep learning field. It has arguably become the de facto architecture for complex Natural Language Processing (NLP) tasks, and it has also reached state-of-the-art performance in other domains, including computer vision and reinforcement learning. Transformers, as the title of the original paper implies, are almost entirely based on a concept known as attention. Attention allows models to "focus" on different parts of an input while considering the entire context of the input, unlike an RNN, which operates on the data sequentially. In this practical, we will introduce attention in greater detail and build the entire transformer architecture block by block to see why it is such a robust and powerful architecture. |
[Graph Neural Networks](https://github.com/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/GNN_practical.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/GNN_practical.ipynb) | In this tutorial, we will be learning about Graph Neural Networks (GNNs), a topic which has exploded in popularity in both research and industry. We will start with a refresher on graph theory, then dive into how GNNs work at a high level. Next, we will cover some popular GNN implementations and see how they work in practice. |
[Deep Generative Models](https://github.com/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/deep_generative_models.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/deep_generative_models.ipynb) | In this practical, we will investigate the fundamentals of generative modelling, a machine learning framework that allows us to learn how to sample new, unseen data points that match the distribution of our training dataset. Though generative modelling is a powerful and flexible framework that has provided many exciting advances in ML, it has its own challenges and limitations. This practical will walk you through these challenges and illustrate how to address them by implementing a Denoising Diffusion Model (a.k.a. a Score-Based Generative Model), which is the backbone of the recent and exciting [DALL-E 2](https://openai.com/dall-e-2/) and [Imagen](https://imagen.research.google/) models that we’ve all seen on [Twitter](https://twitter.com/search?q=%23dalle2%20%23imagen&src=typed_query). |
[Intro to Reinforcement Learning](https://github.com/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/introduction_to_reinforcement_learning.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/introduction_to_reinforcement_learning.ipynb) | In this tutorial, we will be learning about Reinforcement Learning (RL), a type of Machine Learning where an agent learns to choose actions in an environment that lead to maximal reward in the long run. RL has seen tremendous success on a wide range of challenging problems, such as learning to play complex video games like [Atari](https://www.deepmind.com/blog/agent57-outperforming-the-human-atari-benchmark), [StarCraft II](https://www.deepmind.com/blog/alphastar-mastering-the-real-time-strategy-game-starcraft-ii) and [Dota 2](https://openai.com/five/). In this introductory tutorial, we will solve the classic [CartPole](https://www.gymlibrary.ml/environments/classic_control/cart_pole/) environment, where an agent must learn to balance a pole on a cart, using several different RL approaches. Along the way, you will be introduced to some of the most important concepts and terminology in RL. |
[Array Algebra](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/prac-array-algebra/practicals/array_algebra.ipynb) <br /> <br /> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/prac-array-algebra/practicals/array_algebra.ipynb) | In this tutorial, we’ll look at one of the cornerstones of modern deep learning: array programming. Think of this tutorial as a gym or obstacle course for working with arrays. We’ll help you gain intuition about multidimensional arrays (“tensors”), so you can better understand transposing, broadcasting, contraction, and the other “algebraic moves” that make array programming both frustrating and rewarding. There’ll be practical examples as well as theoretical exercises to help you develop your understanding. |
This repository contains the practical notebooks for the Deep Learning Indaba
2022, held at SUP’COM University in Tunis, Tunisia.
See [www.deeplearningindaba.com](http://www.deeplearningindaba.com) for more details.
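
The JAX practical listed above is built around `grad`, `vmap` and `jit`. As a rough taste of how those transformations compose, here is a minimal sketch that fits the same kind of affine model used in the Bayesian practical's warm-up. It is not code from any of the notebooks; the names `predict`, `mse_loss`, `sgd_step` and `LEARNING_RATE`, and the toy data, are our own illustrative choices.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size chosen for this toy problem


def predict(params, x):
    """Affine model f(x) = w * x + b for a single scalar input x."""
    w, b = params
    return w * x + b


# vmap lifts the single-input model to a whole batch of inputs:
# params are shared (in_axes=None), x is mapped over its leading axis.
batched_predict = jax.vmap(predict, in_axes=(None, 0))


def mse_loss(params, xs, ys):
    """Mean squared error of the affine model over the dataset."""
    return jnp.mean((batched_predict(params, xs) - ys) ** 2)


@jax.jit  # compile the gradient computation and the update into one XLA call
def sgd_step(params, xs, ys):
    # grad differentiates mse_loss w.r.t. its first argument (the params tuple).
    grads = jax.grad(mse_loss)(params, xs, ys)
    return tuple(p - LEARNING_RATE * g for p, g in zip(params, grads))


# Noiseless toy data generated with w = 2 and b = -0.5.
xs = jnp.linspace(-1.0, 1.0, 20)
ys = 2.0 * xs - 0.5

params = (jnp.array(0.0), jnp.array(0.0))  # start from w = 0, b = 0
for _ in range(500):
    params = sgd_step(params, xs, ys)

print("fitted w, b:", [float(p) for p in params])
```

Plain gradient descent keeps the sketch short; swapping in an Optax optimiser, as the Bayesian notebook's `optimise` helper does with `optax.adam`, would be a natural next step.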
================================================
FILE: practicals/Bayesian_Deep_Learning_Prac.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "m2s4kN_QPQVe"
},
"source": [
"# Bayesian Deep Learning Practical\n",
"\n",
"<img src=\"https://i.imgur.com/btStvUL.png\" width=\"90%\" />\n",
"\n",
"\n",
"\n",
"<a href=\"https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/Bayesian_Deep_Learning_Prac.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"© Deep Learning Indaba 2022. Apache License 2.0.\n",
"\n",
"**Authors:**\n",
"\n",
"Javier Antorán, James Allingham.\n",
"\n",
"**Reviewers:**\n",
"\n",
"Kale-ab Tessera, Ruan van der Merwe\n",
"\n",
"**Introduction:** \n",
"\n",
"Bayesian inference provides us with the tools to update our beliefs consistently when we observe data. Compared to the, more common, loss minimisation approach to learning, Bayesian methods offer us **calibrated uncertainty estimates**, **resistance to overfitting**, and even approaches to **select hyper-parameters without a validation set**. 🚀In this prac we will learn to do all of these things!🚀\n",
"\n",
"\n",
"\n",
"**Aims/Learning Objectives:**\n",
"\n",
"\n",
"* Understand the tradeoffs of Maximum Likelihood vs fully Bayesian learning **[Sections 0 and 1]**\n",
"* Implement Bayesian linear regression **[Section 1]**\n",
"* Understand the challenges of Bayesian inference in non-conjugate models and the need for approximate inference **[Section 2]**\n",
"* Implement “black box” variational inference **[Section 3]**\n",
"* Understand the tradeoffs of different methods for approximate Bayesian Inference, e.g. variational inference vs Monte Carlo methods **[Sections 4 and 5]**\n",
"\n",
"**Prerequisites:**\n",
"\n",
"* Familiarity with Jax\n",
"* Basic Linear Algebra\n",
"* Basics of Bayesian inference [here is a 15 min video on the topic](https://www.youtube.com/watch?v=HZGCoVF3YvM) and / or Bayesian Inference parallel talk from the Indaba\n",
"* Recommended: Attend the Monte Carlo 101 parallel.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aOSzO91vNdcJ"
},
"source": [
"**Topics:** \n",
"\n",
"\n",
"0. [Warmup - Standard Linear Regression](#S0) \n",
"1. [From Regression to Bayesian Linear Regression](#S1) \n",
"2. [Logistic Regression & the Need for Approximate Inference](#S2) \n",
"3. [Black Box Variational Inference](#S3) \n",
"4. [Bayesian Neural Networks](#S4) \n",
"5. [Hamiltonian Monte Carlo (Optional)](#S5) \n",
"\n",
"\n",
"\n",
"\\\n",
"**Before you start:**\n",
"\n",
"For this practical, you dont need any fancy computers. We are working smarter not harder! Set your runtime to CPU to get a guaranteed fast runtime allocation.\n",
"\n",
"This prac contains excercises and sections at multiple levels of difficulty. These are labelled <font color='green'>`Base`</font>,\n",
" <font color='orange'>`Intermediate`</font> and\n",
" <font color='red'>`Advanced`</font>. Only the <font color='green'>`Base`</font> exercises are part of the core prac. Non-core sections, for instance all of the <font color='red'>`Advanced`</font> ones, are labelled (Optional). If you do not have experience working with matrix calculus or probability disrtibutions, it is best to skip optional sections to ensure you reach the end of the prac.\n",
"\n",
"This prac has a lot of content so if you get stuck it is best to ask a neighbour or tutor for help fast. After every code segment, there are a list of questions for you to answer. You are not expected to know the correct answer to all of these and your responses are not recorded anywhere. The questions are just there to get you thinking about the topic and discussing with others. Dont spend too much time on any single set of questions!\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6EqhIg1odqg0"
},
"source": [
"## Installation and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "4boGA9rYdt9l"
},
"outputs": [],
"source": [
"%%capture\n",
"## Install and import anything required. Capture hides the output from the cell.\n",
"# @title Install and import required packages. (Run Cell)\n",
"\n",
"import subprocess\n",
"import os\n",
"\n",
"# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n",
"try:\n",
" subprocess.check_output('nvidia-smi')\n",
" print(\"a GPU is connected.\")\n",
"except Exception: \n",
" # TPU or CPU\n",
" if \"COLAB_TPU_ADDR\" in os.environ and os.environ[\"COLAB_TPU_ADDR\"]:\n",
" print(\"A TPU is connected.\")\n",
" import jax.tools.colab_tpu\n",
" jax.tools.colab_tpu.setup_tpu()\n",
" else:\n",
" print(\"Only CPU accelerator is connected.\")\n",
" # x8 cpu devices - number of (emulated) host devices\n",
" os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n",
" \n",
"!pip install optax\n",
"!pip install livelossplot\n",
"!pip install numpyro\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.numpy.linalg import inv\n",
"from jax import grad, jit, vmap, random, value_and_grad\n",
"from jax.scipy.stats import norm, bernoulli\n",
"from jax.nn import sigmoid, tanh\n",
"from jax.nn.initializers import normal\n",
"from jax.scipy.linalg import solve\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"import numpy as np\n",
"\n",
"from functools import partial\n",
"\n",
"import optax\n",
"from optax import sigmoid_binary_cross_entropy\n",
"\n",
"\n",
"from livelossplot import PlotLosses"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Aj2sIrAyx3hx"
},
"outputs": [],
"source": [
"# @title Helper Functions. (Run Cell)\n",
"# @markdown You dont need to change these throughout the practical.\n",
"# @markdown But take a look if you are curious! 🐈\n",
"\n",
"import copy\n",
"from typing import Dict\n",
"\n",
"\n",
"def plot_performance(data: Dict, title: str):\n",
" runs = list(data.keys())\n",
" time = list(data.values())\n",
"\n",
" # creating the bar plot\n",
" plt.bar(runs, time, width=0.35)\n",
"\n",
" plt.xlabel(\"Implementation\")\n",
" plt.ylabel(\"Average time taken (in s)\")\n",
" plt.title(title)\n",
" plt.show()\n",
"\n",
" best_perf_key = min(data, key=data.get)\n",
" all_runs_key = copy.copy(runs)\n",
"\n",
" # all_runs_key_except_best\n",
" all_runs_key.remove(best_perf_key)\n",
"\n",
" for k in all_runs_key:\n",
" print(\n",
" f\"{best_perf_key} was {round((data[k]/data[best_perf_key]),2)} times faster than {k} !!!\"\n",
" )\n",
"\n",
"\n",
"def errorfill(\n",
" x,\n",
" y,\n",
" yerr,\n",
" color=None,\n",
" alpha_fill=0.3,\n",
" line_alpha=1,\n",
" ax=None,\n",
" lw=1,\n",
" linestyle=\"-\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=None,\n",
" markevery=None,\n",
"):\n",
" ax = ax if ax is not None else plt.gca()\n",
" if color is None:\n",
" color = ax._get_lines.color_cycle.next()\n",
" if np.isscalar(yerr) or len(yerr) == len(y):\n",
" ymin = y - yerr\n",
" ymax = y + yerr\n",
" elif len(yerr) == 2:\n",
" ymin, ymax = yerr\n",
" plt_return = ax.plot(\n",
" x,\n",
" y,\n",
" color=color,\n",
" lw=lw,\n",
" linestyle=linestyle,\n",
" alpha=line_alpha,\n",
" label=label,\n",
" marker=marker,\n",
" markersize=markersize,\n",
" markevery=markevery,\n",
" )\n",
" ax.fill_between(\n",
" x, ymax, ymin, color=color, alpha=alpha_fill, linewidths=fill_linewidths\n",
" )\n",
" return plt_return\n",
"\n",
"\n",
"def plot_basic_data(\n",
" x_data, y_data, parameters_list=None, title=\"Observed data\", ylim=None\n",
"):\n",
" xlim = [jnp.min(x_data) - 1, jnp.max(x_data) + 1]\n",
"\n",
" if ylim is None:\n",
" ylim = [jnp.min(y_data) - 1, jnp.max(y_data) + 1]\n",
" fig, ax = plt.subplots()\n",
"\n",
" if parameters_list is not None:\n",
" x_pred = np.linspace(xlim[0], xlim[1], 100)\n",
" for parameters in parameters_list:\n",
" y_pred = parameters[0] + parameters[1] * x_pred\n",
" ax.plot(x_pred, y_pred, \":\", color=[1, 0.7, 0.6])\n",
"\n",
" parameters = parameters_list[-1]\n",
" y_pred = parameters[0] + parameters[1] * x_pred\n",
" ax.plot(x_pred, y_pred, \"-\", color=[1, 0, 0], lw=2)\n",
"\n",
" ax.plot(x_data, y_data, \"ob\")\n",
" ax.set(xlabel=\"Input x\", ylabel=\"Output y\", title=title, xlim=xlim, ylim=ylim)\n",
" ax.grid()\n",
"\n",
"\n",
"def optimise(objective, params, plotting_func, LR, MAX_STEPS, LOG_EVERY):\n",
"\n",
" optimiser = optax.chain(optax.adam(LR, b1=0.99, b2=0.999))\n",
" opt_state = optimiser.init(params)\n",
"\n",
" def gen_update(objective):\n",
" def update(params, opt_state, key):\n",
" # get data neded for training\n",
" value, grads = jax.value_and_grad(objective)(params, key)\n",
" grads = jax.tree_map(lambda x: -x, grads)\n",
" updates, opt_state = optimiser.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" return value, params, opt_state\n",
"\n",
" return jit(update)\n",
"\n",
" update = gen_update(objective)\n",
"\n",
" plotlosses = PlotLosses()\n",
" objective_values = []\n",
" artist_list = []\n",
"\n",
" key = random.PRNGKey(42)\n",
" keys = random.split(key, MAX_STEPS)\n",
" # Training & evaluation loop.\n",
" for step in range(MAX_STEPS):\n",
" objective_value, params, opt_state = update(params, opt_state, keys[step])\n",
" objective_values.append(objective_value)\n",
"\n",
" if step % LOG_EVERY == 0:\n",
" # Plot objective curve\n",
" objective_mean = jnp.array(objective_values).mean()\n",
" plotlosses.update(\n",
" {\n",
" \"objective_value\": objective_mean,\n",
" }\n",
" )\n",
" plotlosses.send()\n",
" objective_values = []\n",
"\n",
" # Plot loss landscape and variational distribution\n",
" if plotting_func is not None:\n",
" plotting_func(params)\n",
"\n",
" return params\n",
"\n",
"\n",
"def generate_loss_grid(loss_fun, grid_size, lim0, lim1):\n",
"\n",
" x0, x1 = np.linspace(lim0[0], lim0[1], num=grid_size), np.linspace(\n",
" lim1[0], lim1[1], num=grid_size\n",
" )\n",
" x0_grid, x1_grid = np.meshgrid(x0, x1)\n",
"\n",
" param_mat = jnp.stack([x0_grid.ravel(), x1_grid.ravel()], axis=0)\n",
"\n",
" vmap_loss_fun = vmap(loss_fun)\n",
" loss_grid = vmap_loss_fun(param_mat.T).reshape(grid_size, grid_size)\n",
" return x0_grid, x1_grid, loss_grid\n",
"\n",
"\n",
"def plot_log_gaussian_ellipse(\n",
" ax,\n",
" mean,\n",
" cov,\n",
" color=\"b\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=None,\n",
" MAP_size=5,\n",
" std_levels=[1, 2, 4, 6],\n",
"):\n",
"\n",
" eigenvalues, eigenvectors = jnp.linalg.eigh(cov)\n",
" theta = np.linspace(0, 2 * np.pi, 1200)\n",
" std_ellipsis = (np.sqrt(eigenvalues[None, :]) * eigenvectors) @ jnp.stack(\n",
" [np.sin(theta), np.cos(theta)]\n",
" )\n",
"\n",
" artists = []\n",
"\n",
" for level in std_levels:\n",
" artists.append(\n",
" ax.plot(\n",
" mean[0] + level * std_ellipsis[0, :],\n",
" mean[1] + level * std_ellipsis[1, :],\n",
" c=color,\n",
" alpha=alpha,\n",
" lw=lw,\n",
" )\n",
" )\n",
"\n",
" artists.append(\n",
" ax.scatter(\n",
" [mean[0]],\n",
" [mean[1]],\n",
" MAP_size,\n",
" color=color,\n",
" label=label,\n",
" alpha=alpha,\n",
" marker=\"x\",\n",
" )\n",
" )\n",
"\n",
" return artists"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "mSBrgb2SzR4Q"
},
"outputs": [],
"source": [
"# @title Check the device you are using (Run Cell)\n",
"# @markdown Just CPU is fine!\n",
"print(f\"Num devices: {jax.device_count()}\")\n",
"print(f\" Devices: {jax.devices()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w8CdWv4wZQPL"
},
"source": [
"# Section 0 - Warmup <a name=\"S0\"></a>\n",
"# You are a born Bayesian learning machine!\n",
"\n",
"\n",
"This section introduces the basic concepts that will be used throught the rest of the prac: affine linear models, weight norm regularisation and basic optimisation. Dont spent too much time on this section -- 20 mins max.\n",
"\n",
"\n",
"## Standard Linear Regression\n",
"\n",
"In regression, we aim to find a function $f$ that maps inputs $x \\in R$ to corresponding outputs $y \\in R$. To start, lets choose the affine function $f(x, \\color{blue}{w, b}) = \\color{blue}{w} x + \\color{blue}{b}$ where the learnable parameters $(\\color{blue}{w, b}) \\in \\mathcal{R}^2$ are <font color='blue'>` painted blue`</font>. \n",
"\n",
"Sidenote: \"affine\" means: \"composed of a multiplication and a sum\".\n",
"\n",
"In other words, we assume that the value of the nth target $y_n$ can be modelled as\n",
"\n",
"$y_n = \\underbrace{\\color{blue}{w} x_n + \\color{blue}{b}}_{f(x_n, \\color{blue}{w, b})} + \\color{red}{\\epsilon_n}$\n",
"\n",
"with $\\color{red}{\\epsilon_n}$ refering to residual noise (we will explain what this is below). We do not want that noise to be learnt by $f$.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "gcqpxisITKYW"
},
"outputs": [],
"source": [
"# @title First dataset\n",
"# @markdown Let's build a simple dataset, with 6 $(x,y)$ pairs.\n",
"\n",
"x_data_list_outlier = jnp.array([1, 1.5, 2, 3, 4, 5])\n",
"x_data_list_outlier = (\n",
" x_data_list_outlier - x_data_list_outlier.mean()\n",
") / x_data_list_outlier.std() #\n",
"y_data_list_outlier = jnp.array([3.1, 2.1, 1.7, 1.3, -5, -0.8]) - 3\n",
"\n",
"plot_basic_data(x_data_list_outlier, y_data_list_outlier, ylim=[-5, 1])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "14KthOtTZRod"
},
"source": [
"### Finding plausible solutions manually\n",
"\n",
"When faced with some data, there are often multiple plausible explanations for it, that is multiple plausible functions that fit the data well. Lets try to come up with some plausible functions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FLvxEOBtWrSF"
},
"source": [
"**Code Task:** <font color='green'>`Base`</font>\n",
"1. Move the two sliders below to set $\\color{blue}{b}$ and $\\color{blue}{w}$, and press \"Run cell\" on the code cell below. \n",
"2. Is your $f(x)$ a good fit for the blue data points?\n",
"3. Repeat 1-2 until you have between 5 and 10 red lines you are happy with."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A_8hyJrhdy6v"
},
"outputs": [],
"source": [
"parameters_list = [] # Used to track which parameters were tried."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "iYl7LM7kWYNG"
},
"outputs": [],
"source": [
"# @title { run: \"auto\" }\n",
"b = -1.9 # @param {type:\"slider\", min:-5, max:5, step:0.01}\n",
"w = -1.21 # @param {type:\"slider\", min:-5, max:5, step:0.01}\n",
"print(\"Plotting line\", w, \"* x +\", b)\n",
"parameters = [b, w]\n",
"parameters_list.append(parameters)\n",
"plot_basic_data(\n",
" x_data_list_outlier,\n",
" y_data_list_outlier,\n",
" parameters_list,\n",
" title=\"Observed data and my first predictions\",\n",
" ylim=[-5, 1],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jnWK16WVj8ah"
},
"outputs": [],
"source": [
"manual_parameters_list = jnp.array(\n",
" parameters_list\n",
") # turn your chosen parameter values into a jnp array"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3t_g88TNN-Ok"
},
"source": [
"### Now with Machine learning:\n",
"\n",
"Now that we have manually solved the problem, we are going to find some regression functions with machine learning."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9qxh20jncNDf"
},
"source": [
"#### Introducing matrix notation! \n",
"\n",
"For convenience we are going to work with matrix notation.\n",
"\n",
"\n",
"Our targets can be written as a vector $\\underline{y} = [y_0, y_1, ...]^T$ and our inputs as $\\underline{x} = [x_0, x_1, ...]^T$. Both vectors have length $N=6$, the number of observations.\n",
"\n",
"\n",
"The bias (or offset) $\\color{blue}{b}$ in our regression function $f(x, \\color{blue}{w, b}) = \\color{blue}{b} + \\color{blue}{w} x $ can be written as the product of an additional weight $\\color{blue}{b}$ with an input that takes the constant value $1$. That is, $\\color{blue}{b} + \\color{blue}{w} x = \\color{blue}{b} \\cdot 1 + \\color{blue}{w}x$. Now we have a weight vector $\\color{blue}{\\underline{w} = [b, w] }$ of dimension $D=2$ and a feature vector $[1, x]$. We can write our vectorised function as $f(x, \\color{blue}{\\underline{w}}) = [1, x]^T \\color{blue}{\\underline{w}}$. \n",
"\n",
"We refer to the mapping $x \\to [1, x]$ as a mapping into the *affine basis*. \n",
"\n",
"When we apply this mapping, we can stack all of our observations into a matrix $X = [[1, x_0], [1, x_1], [1, x_2], ...] \\in \\mathbb{R}^{N \\times 2}$.\n",
"\n",
"\n",
"\n",
"Our full dataset is now modelled as $\\underline{y} = X \\cdot \\color{blue}{\\underline{w}} + \\color{red}{\\underline{\\epsilon}}$\n",
"\n"
]
},
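{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small numerical check of the bias-as-weight trick, using a hand-written $3 \\times 2$ matrix already in the affine basis (the three inputs and the values of $b$ and $w$ are arbitrary, chosen only for illustration): multiplying $X$ by $[b, w]$ gives the same predictions as computing $b + w x$ point by point.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"# Three toy inputs and the matching affine-basis matrix (column of ones, then x)\n",
"x = jnp.array([-1.0, 0.0, 2.0])\n",
"X = jnp.array([[1.0, -1.0],\n",
"               [1.0, 0.0],\n",
"               [1.0, 2.0]])\n",
"b, w = 0.5, -2.0\n",
"w_vec = jnp.array([b, w])  # weight vector [b, w]\n",
"\n",
"pred_matrix = X @ w_vec  # X . w for all points at once\n",
"pred_scalar = b + w * x  # b + w x, elementwise\n",
"print(pred_matrix, pred_scalar)  # both give 2.5, 0.5, -3.5\n",
"```"
]
},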
{
"cell_type": "markdown",
"metadata": {
"id": "52V8T_cOOZXK"
},
"source": [
"**Coding task:**<font color='green'>`Base`</font>\n",
"\n",
"Implement the mapping onto the affine basis.\n",
"\n",
"* Hint: useful methods are `jnp.concatenate` and `jnp.ones`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P0NmMLl3c0ui"
},
"outputs": [],
"source": [
"def affine_basis(x):\n",
" \"\"\"\n",
" Maps a vector of inputs [x_0, x_1, ...]\n",
" onto the affine basis [[1, x_0], [1, x_1], [1, x_2], ...]\n",
"\n",
" Args:\n",
" x: jnp array of shape (N,) or (N,1)\n",
" Returns:\n",
" X: jnp array of shape (N,2)\n",
" \"\"\"\n",
" if x.ndim == 1:\n",
" x = x.copy()[:, None]\n",
"\n",
" # Your code goes here\n",
" return X"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XQmKx0KUnge_"
},
"outputs": [],
"source": [
"# @title Run me to test your code\n",
"\n",
"\n",
"X = affine_basis(x_data_list_outlier)\n",
"\n",
"X_correct = jnp.stack([jnp.ones(len(x_data_list_outlier)), x_data_list_outlier], axis=1)\n",
"\n",
"assert jnp.allclose(X_correct, X), \"X is not calculated correctly\"\n",
"\n",
"print(\"It seems correct. Look at the answer below to compare methods.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "gLQnawymm3sd"
},
"outputs": [],
"source": [
"# @title Answer (Try not to peek until you've given it a good try!)\n",
"\n",
"\n",
"def affine_basis(x):\n",
" \"\"\"\n",
" Maps a vector of inputs [x_0, x_1, ...]\n",
" onto the affine basis [[1, x_0], [1, x_1], [1, x_2], ...]\n",
"\n",
" Args:\n",
" x: jnp array of shape (N,) or (N,1)\n",
" Returns:\n",
" X: jnp array of shape (N,2)\n",
" \"\"\"\n",
" if x.ndim == 1:\n",
" x = x.copy()[:, None]\n",
"\n",
" pad_ones = jnp.ones(\n",
" (len(x), 1)\n",
" ) # the bias can be interpreted as a weight that multiplies an input with a constant value of 1\n",
" X = jnp.concatenate([pad_ones, x], axis=1)\n",
" return X"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9q2I2yrSl9t6"
},
"outputs": [],
"source": [
"# @title Now we implement the linear function\n",
"# @markdown It takes a single point as input; we then vmap it to handle batches of inputs.\n",
"\n",
"\n",
"def linear(x, w):\n",
" return x @ w\n",
"\n",
"\n",
"vmap_linear = jit(vmap(linear, in_axes=(0, None)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QBAJvQqDlWNS"
},
"source": [
"**Group task**: <font color='green'>`Base`</font> - Discuss with a neighbour:\n",
"\n",
"* What does the @ operator do? \n",
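" * Solution: it performs matrix multiplication (`jnp.matmul`). For the 1-D vectors used in `linear` it computes the same dot product as [np.dot](https://numpy.org/doc/stable/reference/generated/numpy.dot.html)\n",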
"\n",
"* What would have been an alternative way to implement the linear function without using @?\n",
"\n",
"* What if we want to evaluate the function for multiple weight settings $\\color{blue}{\\underline{w}}$, would we need to change our code?\n",
"\n",
"If you are unsure about any of these, get a tutor involved in your discussion!"
]
},
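{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference once you have discussed the questions: a minimal sketch of two alternatives to `@`, and of how the same `linear` function handles several weight settings at once. The toy values are hypothetical. Stacking $K$ weight vectors as the columns of a $(2, K)$ matrix means `x @ W` returns one prediction per weight setting, so `linear` itself does not need to change; the practical relies on this later when it plots `manual_parameters_list.T`.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def linear(x, w):\n",
"    return x @ w  # matrix multiplication; a dot product for 1-D x and w\n",
"\n",
"\n",
"def linear_dot(x, w):\n",
"    return jnp.dot(x, w)  # same result as @ for these shapes\n",
"\n",
"\n",
"def linear_sum(x, w):\n",
"    return jnp.sum(x * w)  # explicit multiply-and-add (single weight vector only)\n",
"\n",
"\n",
"# Toy data: N=3 points already mapped onto the affine basis [1, x]\n",
"X = jnp.array([[1.0, -1.0], [1.0, 0.0], [1.0, 2.0]])\n",
"w = jnp.array([0.5, -2.0])\n",
"\n",
"out_at = jax.vmap(linear, in_axes=(0, None))(X, w)\n",
"out_dot = jax.vmap(linear_dot, in_axes=(0, None))(X, w)\n",
"out_sum = jax.vmap(linear_sum, in_axes=(0, None))(X, w)\n",
"\n",
"# Several weight settings at once: stack them as the columns of a (2, K) matrix\n",
"W = jnp.stack([w, jnp.array([0.0, 1.0]), jnp.array([-1.0, 0.5])], axis=1)\n",
"out_many = jax.vmap(linear, in_axes=(0, None))(X, W)  # shape (N, K) = (3, 3)\n",
"```"
]
},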
{
"cell_type": "markdown",
"metadata": {
"id": "5VasjmzccNH6"
},
"source": [
"#### Let's find the line that minimises the loss\n",
"\n",
"\n",
"**Code task:** <font color='green'>`Base`</font>\n",
"\n",
"Implement the squared error (SE) loss from the Introduction to JAX prac:\n",
"<center>\n",
"$\\frac{1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2$\n",
"</center>\n",
"\n",
"* Hint: $||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 = \\sum_{n=1}^{N} (y_n - x_{n} \\cdot \\color{blue}{\\underline{w}})^2$ is just the squared [Euclidean norm](https://en.wikipedia.org/wiki/Norm_(mathematics)), where $x_n$ is the $n$-th row of $X$. The output is a single number."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cYjm5aFuRbz-"
},
"outputs": [],
"source": [
"def linear_regression_loss(X, y, vmap_model, w):\n",
" \"\"\"\n",
" Computes SE loss\n",
" \n",
" Args:\n",
" X: jnp array of shape (N,2)\n",
" y: jnp array of shape (N,)\n",
" w: jnp array of shape (D,)\n",
" vmap_model: function that maps (X, w) onto y\n",
" Returns:\n",
" squared_error: scalar\n",
" \"\"\"\n",
" predictions = vmap_model(X, w)\n",
" squared_error = # Your code goes here\n",
" return squared_error"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "sYiTU8mDsOCs"
},
"outputs": [],
"source": [
"# @title Run me to test your code\n",
"\n",
"\n",
"mse = linear_regression_loss(\n",
" affine_basis(x_data_list_outlier),\n",
" y_data_list_outlier,\n",
" vmap_linear,\n",
" jnp.array([2, 2]),\n",
")\n",
"\n",
"mse_correct = (\n",
" 0.5\n",
" * (\n",
" (\n",
" vmap_linear(affine_basis(x_data_list_outlier), jnp.array([2, 2]))\n",
" - y_data_list_outlier\n",
" )\n",
" ** 2\n",
" ).sum()\n",
")\n",
"\n",
"assert jnp.allclose(mse_correct, mse), \"The loss is not calculated correctly\"\n",
"\n",
"print(\"It seems correct. Look at the answer below to compare methods.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xS719wz6sOFU"
},
"outputs": [],
"source": [
"# @title Answer (Try not to peek until you've given it a good try!)\n",
"\n",
"\n",
"def linear_regression_loss(X, y, vmap_model, w):\n",
" \"\"\"\n",
" Computes MSE loss\n",
"\n",
" Args:\n",
" X: jnp array of shape (N,2)\n",
" y: jnp array of shape (N,)\n",
" w: jnp array of shape (D,)\n",
" vmap_model: function that maps (X, w) onto y\n",
" Returns:\n",
" squared_error: scalar\n",
" \"\"\"\n",
" predictions = vmap_model(X, w)\n",
" squared_error = 0.5 * ((y - predictions) ** 2).sum()\n",
" return squared_error"
]
},
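{
"cell_type": "markdown",
"metadata": {},
"source": [
"The squared error used here is easy to confuse with the mean squared error (MSE). The loss implemented above is the *halved sum* of squared errors, whereas the MSE divides the sum by $N$. The sketch below (toy residuals chosen for illustration) shows that the two differ only by a positive constant factor, so without a regulariser they share the same minimiser; once a regulariser is added, switching between them effectively rescales $\\alpha$.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"y = jnp.array([1.0, 2.0, 3.0])\n",
"predictions = jnp.array([0.0, 4.0, 1.0])\n",
"residuals = y - predictions  # [1, -2, 2]\n",
"\n",
"se = 0.5 * (residuals**2).sum()  # 0.5 * (1 + 4 + 4) = 4.5\n",
"se_via_norm = 0.5 * jnp.linalg.norm(residuals) ** 2  # same value via the Euclidean norm\n",
"mse = (residuals**2).mean()  # 9 / 3 = 3.0\n",
"print(se, se_via_norm, mse)\n",
"```"
]
},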
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "JU3TM6P8rVvT"
},
"outputs": [],
"source": [
"# @title Let's optimise our parameters with the squared error loss\n",
"# @markdown Note that in this prac we will be working with objectives, which are the negative of the loss. That means it is good if the training curve goes up!\n",
"\n",
"\n",
"def gen_objective(X, y, vmap_model):\n",
" def objective(params, key):\n",
" return -linear_regression_loss(X, y, vmap_model, params)\n",
"\n",
" return objective\n",
"\n",
"\n",
"objective = gen_objective(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, vmap_linear\n",
")\n",
"\n",
"optimised_params = optimise(\n",
" objective,\n",
" params=jnp.array([0.0, 0.0]),\n",
" plotting_func=None,\n",
" LR=1e-2,\n",
" MAX_STEPS=1000,\n",
" LOG_EVERY=100,\n",
")\n",
"\n",
"print(\"Your optimised parameter values are\", optimised_params)"
]
},
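{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `optimise` helper used above is defined earlier in the notebook and its internals are not shown here, so the following is only a hypothetical sketch of the core idea, not its actual implementation: repeatedly move the parameters in the direction of the gradient of the objective, computed with `jax.grad`. Because the objective is the negative loss, this is gradient *ascent*. The function name `gradient_ascent`, the learning rate and the toy dataset are assumptions; the result is compared against the exact least-squares solution from `jnp.linalg.lstsq`.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def gradient_ascent(objective, params, lr=1e-2, steps=1000, seed=0):\n",
"    # Hypothetical minimal optimiser: plain gradient ascent on objective(params, key)\n",
"    grad_fn = jax.jit(jax.grad(objective))  # gradient w.r.t. the first argument only\n",
"    key = jax.random.PRNGKey(seed)\n",
"    for _ in range(steps):\n",
"        key, subkey = jax.random.split(key)\n",
"        params = params + lr * grad_fn(params, subkey)  # step uphill on the objective\n",
"    return params\n",
"\n",
"\n",
"# Toy dataset already mapped onto the affine basis [1, x]\n",
"X = jnp.array([[1.0, -1.0], [1.0, 0.0], [1.0, 1.0]])\n",
"y = jnp.array([0.0, 1.0, 2.5])\n",
"\n",
"\n",
"def objective(params, key):\n",
"    return -0.5 * ((y - X @ params) ** 2).sum()  # negative squared error\n",
"\n",
"\n",
"params_ga = gradient_ascent(objective, jnp.zeros(2), lr=0.1, steps=500)\n",
"params_ls = jnp.linalg.lstsq(X, y)[0]  # exact least-squares solution for comparison\n",
"print(params_ga, params_ls)\n",
"```\n",
"\n",
"With too large a learning rate the updates overshoot and diverge; with too small a one, many more steps are needed to reach the same point."
]
},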
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "SW8kWeyM25dz"
},
"outputs": [],
"source": [
"# @title Let's look at how your manually chosen regression functions compare with the optimised one\n",
"\n",
"\n",
"plt.figure(dpi=100)\n",
"\n",
"xlim = [jnp.min(x_data_list_outlier) - 1, jnp.max(x_data_list_outlier) + 1]\n",
"ylim = [jnp.min(y_data_list_outlier) - 1, jnp.max(y_data_list_outlier) + 1]\n",
"\n",
"x_test = jnp.linspace(xlim[0], xlim[1], 201)\n",
"X_test = affine_basis(x_test)\n",
"\n",
"ax = plt.gca()\n",
"\n",
"ax.plot(\n",
" x_test,\n",
" vmap_linear(X_test, manual_parameters_list.T),\n",
" c=\"red\",\n",
" lw=0.5,\n",
" label=\"manual choice\",\n",
")\n",
"ax.plot(\n",
" x_test, vmap_linear(X_test, optimised_params), c=\"purple\", lw=2, label=\"optimised\"\n",
")\n",
"\n",
"ax.plot(x_data_list_outlier, y_data_list_outlier, \"ob\")\n",
"ax.set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
"ax.grid(alpha=0.3)\n",
"ax.set_title(\"Regression function comparison\")\n",
"ax.legend(loc=\"lower left\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6_prh3fO3B5w"
},
"source": [
"**Group task**: <font color='green'>`Base`</font> - Discuss with a neighbour and / or tutor:\n",
"\n",
"* Do your manual choices look similar to those of your neighbour? \n",
"\n",
"* Are you more confident about the functions you chose or the one that the optimiser chose?\n",
"\n",
"* What went wrong with the optimiser, if anything? Hint: can you spot any previously hidden outliers?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-79VgSrB2u2J"
},
"source": [
"#### How can we stop our parameters from overfitting to outliers? You already know the solution: adding regularisation!\n",
"\n",
"The outlier above pushes our regression line to have a larger negative slope than we would like. We can bias our slope parameter $\\color{blue}{w}$ away from solutions with a large (positive or negative) slope by penalising the squared norm of the weight vector $||\\color{blue}{\\underline{w}}||_{2}^2$ (in the code below this also penalises the bias $\\color{blue}{b}$). This prevents the outlier from influencing our regression line too much. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5Wy4eIOrrV0g"
},
"outputs": [],
"source": [
"# @markdown Our *regularised* loss is $\\frac{1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 + \\frac{\\alpha}{2}||\\color{blue}{\\underline{w}}||_{2}^{2}$\n",
"\n",
"# @markdown where $\\alpha$ is the regularisation strength.\n",
"\n",
"\n",
"def regularised_linear_regression_loss(X, y, vmap_model, w, alpha):\n",
" predictions = vmap_model(X, w)\n",
" data_fit = 0.5 * ((y - predictions) ** 2).sum()\n",
" regulariser = 0.5 * alpha * (w**2).sum()\n",
" return data_fit + regulariser"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JbZ6-3aM5Ahr"
},
"source": [
"**Math Task:** <font color='orange'>`Intermediate`</font> (Optional - skip if you are not familiar with matrix calculus)\n",
"\n",
"The regularised least squares solution, i.e. the minimum of $\\frac{1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 + \\frac{\\alpha}{2}||\\color{blue}{\\underline{w}}||_{2}^{2}$ can be calculated in closed form with pen and paper.\n",
"\n",
"Try to derive it!\n",
"\n",
"**Hint1:** take the derivative with respect to $\\color{blue}{\\underline{w}}$ and set it to 0 to find the stationary point. Since the loss is convex, this stationary point is the minimum. \n",
"\n",
"<center>\n",
"$\\frac{\\partial}{\\partial \\color{blue}{\\underline{w}}} \\frac{1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 + \\frac{\\alpha}{2}||\\color{blue}{\\underline{w}}||_{2}^{2} =0\n",
"$\n",
"</center>\n",
"\n",
"**Hint2:** the derivative is:\n",
"\n",
"<center>\n",
"$\\frac{\\partial}{\\partial \\color{blue}{\\underline{w}}} \\frac{1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 + \\frac{\\alpha}{2}||\\color{blue}{\\underline{w}}||_{2}^{2} = -X^T (\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}) + \\alpha \\color{blue}{\\underline{w}}\n",
"$\n",
"</center>\n",
"\n"
]
},
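{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hand-derived gradient can be sanity-checked against automatic differentiation. The sketch below compares `jax.grad` of the regularised loss with the analytic expression $-X^T(\\underline{y} - X\\underline{w}) + \\alpha \\underline{w}$ at an arbitrary point; the toy $X$, $\\underline{y}$, $\\alpha$ and $\\underline{w}$ are made-up values for illustration only.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Made-up toy values, only for checking the gradient formula\n",
"X = jnp.array([[1.0, -1.0], [1.0, 0.5], [1.0, 2.0]])\n",
"y = jnp.array([0.3, -0.7, -2.0])\n",
"alpha = 2.0\n",
"w = jnp.array([0.4, -1.1])  # arbitrary point at which to compare gradients\n",
"\n",
"\n",
"def regularised_loss(w):\n",
"    return 0.5 * jnp.sum((y - X @ w) ** 2) + 0.5 * alpha * jnp.sum(w**2)\n",
"\n",
"\n",
"grad_autodiff = jax.grad(regularised_loss)(w)\n",
"grad_analytic = -X.T @ (y - X @ w) + alpha * w  # equivalently X^T (X w - y) + alpha w\n",
"print(grad_autodiff, grad_analytic)\n",
"```"
]
},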
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P80NpVB9546H"
},
"outputs": [],
"source": [
"# @markdown What is the solution? (Options given already in code)\n",
"\n",
"# @markdown Hint: `inv(A)` computes A^{-1} and `solve(A,B)` computes A^{-1} B\n",
"\n",
"selection = \"X.T @ X + alpha * jnp.eye(w.shape[0])\" # @param ['X.T @ X + alpha * jnp.eye(w.shape[0])', \"solve(X.T @ X + alpha * jnp.eye(X.shape[1]), X.T @ y)\", \"inv(X.T @ X) @ (alpha + X.T @ y)\"]\n",
"print(f\"You selected: {selection}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nS_x2x0Y8Rgs"
},
"outputs": [],
"source": [
"# @title Coding task (Optional): choose one of the 3 options from above\n",
"\n",
"\n",
"def regularised_least_squares_solution(X, y, alpha):\n",
" return # paste solution here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "33Jh6D0-4xmP"
},
"outputs": [],
"source": [
"# @title Let's optimise the parameters with the regulariser and look at how the manually chosen regression functions compare with the regularised and unregularised ones <font color='green'>`Base`</font> {run: \"auto\"}\n",
"\n",
"# @markdown **This does not require solving the optional task above**.\n",
"\n",
"# @markdown Start by setting the slider to $\\alpha=5$.\n",
"\n",
"# @markdown Play around with different $\\alpha$ values to find the one that gives solutions that \"feel\" the best\n",
"\n",
"alpha = 3.58 # @param {type:\"slider\", min:0, max:20, step:0.01}\n",
"\n",
"\n",
"def gen_objective(X, y, vmap_model, alpha):\n",
" def objective(params, key):\n",
" return -regularised_linear_regression_loss(X, y, vmap_model, params, alpha)\n",
"\n",
" return objective\n",
"\n",
"\n",
"objective = gen_objective(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, vmap_linear, alpha\n",
")\n",
"\n",
"optimised_regularised_params = optimise(\n",
" objective,\n",
" params=jnp.array([0.0, 0.0]),\n",
" plotting_func=None,\n",
" LR=1e-2,\n",
" MAX_STEPS=500,\n",
" LOG_EVERY=100,\n",
")\n",
"\n",
"\n",
"plt.figure(dpi=100)\n",
"\n",
"xlim = [jnp.min(x_data_list_outlier) - 1, jnp.max(x_data_list_outlier) + 1]\n",
"ylim = [jnp.min(y_data_list_outlier) - 1, jnp.max(y_data_list_outlier) + 1]\n",
"\n",
"x_test = jnp.linspace(xlim[0], xlim[1], 201)\n",
"X_test = affine_basis(x_test)\n",
"\n",
"ax = plt.gca()\n",
"\n",
"ax.plot(x_test, vmap_linear(X_test, manual_parameters_list.T), c=\"red\", lw=0.5)\n",
"ax.plot(\n",
" x_test,\n",
" vmap_linear(X_test, optimised_params),\n",
" c=\"purple\",\n",
" lw=2,\n",
" label=\"unregularised optimised\",\n",
")\n",
"ax.plot(\n",
" x_test,\n",
" vmap_linear(X_test, optimised_regularised_params),\n",
" c=\"cyan\",\n",
" lw=2,\n",
" label=\"regularised optimised\",\n",
")\n",
"\n",
"ax.plot(x_data_list_outlier, y_data_list_outlier, \"ob\")\n",
"ax.set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
"ax.grid(alpha=0.3)\n",
"ax.legend(loc=\"lower left\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xz0V4ggxUjrp"
},
"outputs": [],
"source": [
"# @title (Optional) If you coded the exact solution, let's compare it with the output of the optimiser\n",
"# @markdown The two parameter vectors will not be exactly the same, because our optimisation has not fully converged.\n",
"\n",
"# Hint: the correct solution to the above question was the second option, solve(X.T @ X + alpha * jnp.eye(X.shape[1]), X.T @ y)\n",
"exact_regularised_params = regularised_least_squares_solution(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, alpha\n",
")\n",
"\n",
"print(\"Your regularised optimised parameter values are\", optimised_regularised_params)\n",
"print(\"Your regularised exact parameter values are\", exact_regularised_params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "FnOovLgA7Lgb"
},
"outputs": [],
"source": [
"# @title Run me to reveal answer to optional math task above, for reference.\n",
"from IPython.display import display, Markdown, Latex\n",
"\n",
"display(Markdown(\"The derivation of the solution to the regularised least squares problem is:\"))\n",
"display(\n",
" Latex(\n",
" \"$-X^T (\\\\underline{y} - X {\\\\underline{w}}) + \\\\alpha {\\\\underline{w}} = 0$\"\n",
" )\n",
")\n",
"display(\n",
" Latex(\n",
" \"$-X^T \\\\underline{y} + X^T X {\\\\underline{w}} + \\\\alpha {\\\\underline{w}} = 0$\"\n",
" )\n",
")\n",
"display(Latex(\"$X^T \\\\underline{y} = (X^T X + \\\\alpha I){\\\\underline{w}}$\"))\n",
"display(Latex(\"$\\\\underline{w} = (X^T X + \\\\alpha I)^{-1} X^T \\\\underline{y}$\"))"
]
},
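{
"cell_type": "markdown",
"metadata": {},
"source": [
"The closed-form solution $\\underline{w} = (X^T X + \\alpha I)^{-1} X^T \\underline{y}$ can be checked numerically. This sketch (the helper names `ridge_solution` and `reg_loss` are hypothetical) uses `jnp.linalg.solve` rather than forming an explicit inverse, which is generally cheaper and numerically more stable, and then confirms with `jax.grad` that the gradient of the regularised loss vanishes at the solution. It rebuilds the practical's dataset so that it runs on its own.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def ridge_solution(X, y, alpha):\n",
"    # Solve (X^T X + alpha I) w = X^T y without forming an explicit inverse\n",
"    D = X.shape[1]\n",
"    return jnp.linalg.solve(X.T @ X + alpha * jnp.eye(D), X.T @ y)\n",
"\n",
"\n",
"def reg_loss(w, X, y, alpha):\n",
"    return 0.5 * jnp.sum((y - X @ w) ** 2) + 0.5 * alpha * jnp.sum(w**2)\n",
"\n",
"\n",
"x = jnp.array([1.0, 1.5, 2.0, 3.0, 4.0, 5.0])\n",
"x = (x - x.mean()) / x.std()  # standardised inputs, as in the practical\n",
"y = jnp.array([3.1, 2.1, 1.7, 1.3, -5.0, -0.8]) - 3\n",
"X = jnp.stack([jnp.ones_like(x), x], axis=1)  # affine basis [1, x]\n",
"\n",
"w_ridge = ridge_solution(X, y, alpha=5.0)\n",
"grad_at_solution = jax.grad(reg_loss)(w_ridge, X, y, 5.0)  # ~0 at the minimum\n",
"print(w_ridge, grad_at_solution)\n",
"```"
]
},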
{
"cell_type": "markdown",
"metadata": {
"id": "EnkYeSCf_VFQ"
},
"source": [
"### Let's investigate the loss landscape: <font color='green'>`Base`</font> \n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vkHnvwgtFbYU"
},
"outputs": [],
"source": [
"# @title { run: \"auto\" }\n",
"# @markdown Let's plot the loss for a large grid of $b$ and $w$ values to see which ones would work well\n",
"\n",
"# @markdown Try different values of alpha to see how they affect the loss landscape\n",
"\n",
"\n",
"alpha = 4.18 # @param {type:\"slider\", min:0, max:20, step:0.01}\n",
"\n",
"\n",
"def generate_loss_fun():\n",
" def loss_fun(w):\n",
" return regularised_linear_regression_loss(\n",
" X=affine_basis(x_data_list_outlier),\n",
" y=y_data_list_outlier,\n",
" w=w,\n",
" vmap_model=vmap_linear,\n",
" alpha=alpha,\n",
" )\n",
"\n",
" return loss_fun\n",
"\n",
"\n",
"loss_fun = generate_loss_fun()\n",
"\n",
"x0_grid, x1_grid, loss_grid = generate_loss_grid(\n",
" loss_fun=loss_fun, grid_size=200, lim0=[-8, 8], lim1=[-8, 8]\n",
")\n",
"\n",
"\n",
"plt.figure(dpi=120)\n",
"plt.pcolormesh(\n",
" x0_grid, x1_grid, -loss_grid, vmin=-70, vmax=-loss_grid.min(), cmap=\"viridis\"\n",
")  # we plot the negative loss (the objective), so higher (yellow) is better\n",
"plt.ylabel(\"w\")\n",
"plt.xlabel(\"b\")\n",
"plt.grid(alpha=0.3)\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Negative regularised loss (objective)\", rotation=270, labelpad=20)\n",
"\n",
"plt.scatter(\n",
" manual_parameters_list[:, 0],\n",
" manual_parameters_list[:, 1],\n",
" 10,\n",
" color=\"red\",\n",
" label=\"manually chosen parameters\",\n",
")\n",
"plt.scatter(\n",
" optimised_regularised_params[0],\n",
" optimised_regularised_params[1],\n",
" 10,\n",
" color=\"cyan\",\n",
" label=\"regularised optimisation\",\n",
")\n",
"plt.scatter(\n",
" optimised_params[0],\n",
" optimised_params[1],\n",
" 10,\n",
" color=\"purple\",\n",
" label=\"unregularised optimisation\",\n",
")\n",
"plt.title(\"Value of loss function for different parameter values\")\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GFg5FS8Mmpi6"
},
"source": [
"* The loss landscape accommodates multiple solutions! All of the parameter combinations in the yellow region of the loss surface have low loss, so they could all have worked quite well! \n",
"\n",
"\n",
"* For larger values of $\\alpha$, the yellow region shrinks and moves towards the origin. This means that out of all the lines that fit the data, we are only keeping the ones that have a smaller parameter norm $||\\color{blue}{\\underline{w}}||^2_2$.\n",
"\n",
"**Question <font color='green'>`Base`</font>:** Given that we have only seen $N=6$ training points, how confident should we be that the line that minimises the loss function above is the \"right\" line?\n",
"\n"
]
},
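{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shrinkage described in the second bullet above can be checked directly with the closed-form regularised least-squares solution $(X^TX + \\alpha I)^{-1} X^T \\underline{y}$ on the practical's dataset. This sketch recomputes the standardised inputs itself so that it is self-contained. Because the inputs are standardised, $X^TX = 6I$ here and the solution is simply $X^T\\underline{y} / (6 + \\alpha)$, so its norm falls monotonically as $\\alpha$ grows.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"x = jnp.array([1.0, 1.5, 2.0, 3.0, 4.0, 5.0])\n",
"x = (x - x.mean()) / x.std()  # standardised inputs, as in the practical\n",
"y = jnp.array([3.1, 2.1, 1.7, 1.3, -5.0, -0.8]) - 3\n",
"X = jnp.stack([jnp.ones_like(x), x], axis=1)  # affine basis [1, x]\n",
"\n",
"alphas = [0.0, 1.0, 5.0, 20.0]\n",
"norms = []\n",
"for alpha in alphas:\n",
"    w = jnp.linalg.solve(X.T @ X + alpha * jnp.eye(2), X.T @ y)\n",
"    norms.append(float(jnp.linalg.norm(w)))\n",
"\n",
"for alpha, norm in zip(alphas, norms):\n",
"    print(f'alpha = {alpha:5.1f}   ||w|| = {norm:.3f}')\n",
"```"
]
},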
{
"cell_type": "markdown",
"metadata": {
"id": "yyXD3p5-97T7"
},
"source": [
"### **Takeaways**: trust your instincts!\n",
"\n",
"1. Our intuition about what is a \"good\" solution aligns more with the solutions chosen by the regularised objective. \n",
"\n",
"2. Using our intuition, we were able to come up with multiple plausible options for lines that fit the data. Doesn't it feel kind of wrong that the optimiser only gives us 1 solution?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dIqcTxinEt7a"
},
"source": [
"# Section 1 - From regression to Bayesian linear regression <a name=\"S1\"></a>\n",
"\n",
"\n",
"In this section, we formalise the intuition from the previous section. We are going to learn about Bayesian inference. \n",
"\n",
"Specifically, we are going to introduce probability distributions and show how to update our beliefs using Bayes' rule. We will compare Bayesian learning with traditional loss-minimisation learning. Finally, we will look at how to make predictions within the Bayesian framework.\n",
"\n",
"Hopefully, by the end of the section, you will be convinced that when you were manually choosing regression parameters at the beginning of the practical, your brain was performing Bayesian inference!\n",
"\n",
"If you are familiar with probability distributions and Bayes rule, try to breeze through this section and do the optional parts at the end of the section. If it is your first time dealing with these concepts, spend more time on the core section. Aim to spend 45 minutes on the section."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ZUp8i37dFbU"
},
"source": [
"## Bayesian linear regression fundamentals\n",
"\n",
"\n",
"\n",
"In the Bayesian framework, we pick a model which we believe could have generated the data. We already have such a model:\n",
"\n",
"<center> \n",
"$\\underline{y} = X \\cdot \\color{blue}{\\underline{w}} + \\color{red}{\\underline{\\epsilon}}$,\n",
"</center>\n",
"\n",
"Although the exact value of $\\color{blue}{\\underline{w}}$ is unknown a priori (before observing the data), we do have some prior knowledge we can incorporate into our problem. As we saw in the previous section, we might not want to choose a very large slope for our regression line, since this often happens when we overfit to some outlier. \n",
"\n",
"Mathematically, we express the fact that we do not know what value $\\color{blue}{\\underline{w}}$ should have, but we think it should be low norm, through the choice of a prior probability distribution $p(\\color{blue}{\\underline{w}})$ that places more probability mass near 0.\n",
"\n",
"The most common choice to satisfy this condition, and easiest to work with, is the 0 mean Gaussian:\n",
"\n",
"<center> \n",
"$\\color{blue}{\\underline{w}} \\sim \\mathcal{N}(0, \\alpha^{-1} I)$ where $I$ is the identity matrix.\n",
"</center>\n",
"\n",
"<center> \n",
"<img src=\"https://i.imgur.com/JWykW0I.png\" width=\"40%\" />\n",
"</center>\n",
"\n",
"The precision parameter $\\alpha$ (inverse of variance) expresses how confident we are that the parameters should be close to 0. If we set $\\alpha=0.001$ we are saying that we have a slight preference for a small regression slope but don't have a strong opinion, so we are also happy with large-slope solutions. If we set $\\alpha=100$ we are saying that we are very confident that only very low-slope solutions are admissible.\n",
"\n",
"Even if we knew the exact value of $\\color{blue}{\\underline{w}}$, we would still be unable to retrieve $\\underline{y}$ given $X$ and $\\color{blue}{\\underline{w}}$. This is because, in general, our observations are corrupted by the residuals $\\color{red}{\\underline{\\epsilon}}$. \n",
"\n",
"$\\color{red}{\\underline{\\epsilon}}$ represents some noise that cannot be predicted from $X$ alone, e.g. thermal noise in the measurement of a voltage value or background radiation in radio-astronomy. We usually assume this noise is the cause of outliers, such as the one in our dataset. Since we do not know what value $\\color{red}{\\underline{\\epsilon}}$ takes a priori, we also model $\\color{red}{\\underline{\\epsilon}}$ with a probability distribution $p(\\color{red}{\\underline{\\epsilon}})$.\n",
"\n",
"\n",
"<center> Let's keep it simple and choose the Gaussian\n",
"$ \\,\\,\\,\\color{red}{\\underline{\\epsilon}} \\sim \\mathcal{N}(0, I)$\n",
"</center>\n",
"\n",
"\n",
"We are ready to write our full probabilistic model. Given inputs and weights, our targets follow the distribution\n",
"\n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) = \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I)$ \n",
"\n",
"Notice that the mean is our prediction $X\\cdot \\color{blue}{\\underline{w}}$ and the covariance is just the covariance of our residuals $I$.\n",
"\n",
"Our weights follow the distribution\n",
"\n",
"$p(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)$\n",
"\n",
"\n",
" "
]
},
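{
"cell_type": "markdown",
"metadata": {},
"source": [
"Samples from the prior make the role of the precision $\\alpha$ concrete. This sketch (the helper name `sample_prior_weights` is hypothetical) draws $\\underline{w} = [b, w]$ from $\\mathcal{N}(0, \\alpha^{-1} I)$ by scaling standard normal samples by the prior standard deviation $1/\\sqrt{\\alpha}$. Each row is one regression line $b + w x$ that the prior considers plausible before seeing any data.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def sample_prior_weights(key, alpha, num_samples, D=2):\n",
"    # w ~ N(0, alpha^{-1} I): scale standard normals by the prior std 1/sqrt(alpha)\n",
"    return jax.random.normal(key, (num_samples, D)) / jnp.sqrt(alpha)\n",
"\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"w_broad = sample_prior_weights(key, alpha=0.001, num_samples=10_000)\n",
"w_narrow = sample_prior_weights(key, alpha=100.0, num_samples=10_000)\n",
"print(w_broad.std(axis=0))  # roughly 1/sqrt(0.001), i.e. about 31.6\n",
"print(w_narrow.std(axis=0))  # roughly 1/sqrt(100) = 0.1\n",
"```\n",
"\n",
"With $\\alpha=0.001$ the sampled lines can be extremely steep; with $\\alpha=100$ they are all nearly flat and close to $y=0$."
]
},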
{
"cell_type": "markdown",
"metadata": {
"id": "G7qzoHuTFQWN"
},
"source": [
"### (Reminder): Probability density functions and likelihood functions\n",
"\n",
"In this section, we are going to be working with probability distributions over continuous random variables. Consider the variable $a \\in \\mathbb{R}$ distributed as $p(a) = \\mathcal{N}(a; 0, 1)$. Since there are infinitely many values $a$ can take, the probability of $a$ taking any single exact value is zero. However, not all values of $a$ are equally plausible: since the mean of $\\mathcal{N}(a; 0, 1)$ is 0, values near 0 are the most likely. We quantify the relative likelihood of the different values $a$ can take with the **probability density function $p(a)$**. In this section, we will work with these functions.\n",
"\n",
"Unlike probabilities which are restricted to the interval [0, 1], probability densities can take values in [0, $\\infty$). \n",
"\n",
"\\\n",
"We can compute the probability of $a$ taking a value in the range $[s,s']$ by integrating the density function in this interval $p(a \\in [s,s']) = \\int_{s}^{s'} p(a) da$. Since the probability of all possible events must sum to 1, we have that $\\int_{-\\infty}^{\\infty} p(a) da = 1$.\n",
"\n",
"When we write $p(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)$ we are saying that $\\color{blue}{\\underline{w}}$ has a Gaussian density with mean $0$ and precision $\\alpha$ (covariance $\\alpha^{-1} I$). Given a particular value of $\\color{blue}{\\underline{w}}$, we can evaluate its density with the density function\n",
"\n",
"$p(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I) = \\left(\\frac{\\alpha}{2 \\pi}\\right)^{D/2} \\exp(\\frac{-\\alpha \\color{blue}{\\underline{w}}^T \\color{blue}{\\underline{w}}}{2})$,\n",
"\n",
"where $D$ is the size of $\\color{blue}{\\underline{w}}$, i.e. 2. \n",
"\n",
"\\\n",
"Notice that the terms outside the exponential do not contain $\\color{blue}{\\underline{w}}$. These are known as constants. The term inside the exponential $\\frac{-\\alpha \\color{blue}{\\underline{w}}^T \\color{blue}{\\underline{w}}}{2}$ becomes larger the smaller the norm of $\\color{blue}{\\underline{w}}$ becomes. Its maximum is attained when $\\color{blue}{\\underline{w}} = 0$.\n",
"\n",
"Assuming we know the value of the weights $\\color{blue}{\\underline{w}}$, our targets $\\underline{y}$ follow a probability distribution $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$. The vertical line \"|\" denotes conditioning. $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$ means: \"the probability density for $\\underline{y}$ conditional on choosing a specific $X$ and $\\color{blue}{\\underline{w}}$\". When we expand this density function, we get\n",
"\n",
"\\\n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) = \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I) = \\frac{1}{ (2 \\pi)^{N/2} } \\exp(\\frac{- (\\underline{y} - X \\cdot \\color{blue}{\\underline{w}})^T (\\underline{y} - X \\cdot \\color{blue}{\\underline{w}})}{2})$\n",
"\n",
"where $N$ is the number of observations, i.e. 6.\n",
"\n",
"The distribution $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$ is special\n",
"because it is not a distribution over the variable we care about, $\\color{blue}{\\underline{w}}$. Instead, $\\color{blue}{\\underline{w}}$ is a parameter of the distribution, influencing the value of its mean. However, we can still view $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$ as a function of $\\color{blue}{\\underline{w}}$. When viewed in this light, probability densities are known as **likelihood functions**. \n",
"\n",
"Likelihood functions can be used for learning. Specifically, we can choose $\\color{blue}{\\underline{w}}$ such that the probability density of our targets conditional on $\\color{blue}{\\underline{w}}$, i.e. $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$, is maximised. This is known as the **maximum likelihood** criterion. \n",
"\n",
"<font color='green'>`Fun fact`</font>: A likelihood function does not need to be normalised. As a result, any positive function of the parameters can be viewed as a likelihood function. For example, minimising a loss $L(\\color{blue}{\\underline{w}})$ is equivalent to maximising the likelihood $\\exp(-L(\\color{blue}{\\underline{w}}))$, so machine learning algorithms that minimise a loss or maximise an objective can be seen as performing maximum likelihood learning.\n",
"\n"
]
},
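{
"cell_type": "markdown",
"metadata": {},
"source": [
"Densities are usually evaluated in log space to avoid numerical underflow when many terms are multiplied. The sketch below writes out the two log densities from this section by hand (the function names `log_prior` and `log_likelihood` are hypothetical and the toy values arbitrary) and checks them against `jax.scipy.stats.multivariate_normal.logpdf`. As a function of $\\underline{w}$, the log-likelihood is the negative squared error loss plus a constant, so maximising it gives the same solution as minimising the SE loss.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax.scipy.stats import multivariate_normal\n",
"\n",
"\n",
"def log_prior(w, alpha):\n",
"    # log N(w; 0, alpha^{-1} I) = (D/2) log(alpha / (2 pi)) - alpha w^T w / 2\n",
"    D = w.shape[0]\n",
"    return 0.5 * D * jnp.log(alpha / (2 * jnp.pi)) - 0.5 * alpha * (w @ w)\n",
"\n",
"\n",
"def log_likelihood(w, X, y):\n",
"    # log N(y; X w, I) = -(N/2) log(2 pi) - ||y - X w||^2 / 2\n",
"    N = y.shape[0]\n",
"    r = y - X @ w\n",
"    return -0.5 * N * jnp.log(2 * jnp.pi) - 0.5 * (r @ r)\n",
"\n",
"\n",
"alpha = 2.0\n",
"w = jnp.array([0.3, -0.8])\n",
"X = jnp.array([[1.0, -1.0], [1.0, 0.5], [1.0, 2.0]])\n",
"y = jnp.array([0.3, -0.7, -2.0])\n",
"\n",
"lp_manual = log_prior(w, alpha)\n",
"lp_jax = multivariate_normal.logpdf(w, jnp.zeros(2), jnp.eye(2) / alpha)\n",
"ll_manual = log_likelihood(w, X, y)\n",
"ll_jax = multivariate_normal.logpdf(y, X @ w, jnp.eye(3))\n",
"print(lp_manual, lp_jax, ll_manual, ll_jax)\n",
"```"
]
},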
{
"cell_type": "markdown",
"metadata": {
"id": "siNhkwB6pbNE"
},
"source": [
"## Bayesian learning of our parameters $\\color{blue}{\\underline{w}}$\n",
"\n",
"\n",
"Learning means identifying the parameter settings that \n",
"1. Are compatible with our prior assumptions, i.e. have low norm. -- this is captured by $p(\\color{blue}{\\underline{w}})$\n",
"**and**\n",
"2. Fit the data well. -- this is quantified by the likelihood function $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$\n",
"\n",
"We can score parameter settings by how well they satisfy both constraints simultaneously by considering the product of the prior and likelihood\n",
"\n",
"<center> \n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}})$.\n",
"</center> \n",
"\n",
"Unfortunately, when you multiply two probability distributions, you are not guaranteed to get something that integrates to 1 (a key requirement for something to be a probability distribution). This is why we need to renormalise by dividing the product above by $p(\\underline{y} | X) = \\int p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) d \\color{blue}{\\underline{w}}$. Putting it all together, we have \n",
"\n",
"\n",
"$$\\frac{p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}})}{p(\\underline{y} | X)}$$.\n",
"\n",
"\n",
"This expression is known as **Bayes' rule** and it yields our posterior distribution over parameters $p(\\color{blue}{\\underline{w}} | \\underline{y}, X)$. The posterior represents how well each parameter setting agrees with our prior and data. If we compute this distribution, we will know which values of $\\color{blue}{\\underline{w}}$ are good for making predictions.\n",
"\n",
"The procedure of computing the posterior is also known as **inference**.\n",
"\n",
"<font color='green'>`Fun fact`</font>: in large-scale machine learning, \"inference\" is often used to mean \"making predictions\". This is a misuse of the word that stuck!"
]
},
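{
"cell_type": "markdown",
"metadata": {
"id": "bayesGridSketchMd"
},
"source": [
"To make the renormalisation step concrete, the cell below is a small numerical sketch; the model and data are made up for illustration. It uses a hypothetical one-parameter model $y = w x + \\epsilon$ with the same unit-variance Gaussian likelihood and $\\mathcal{N}(0, \\alpha^{-1})$ prior as in this section, and evaluates the product of likelihood and prior on a dense grid of $w$ values. Dividing by a Riemann-sum estimate of $p(\\underline{y} | X)$ gives a density that integrates to 1. Its mean should match the closed-form posterior mean for this one-parameter model, $\\frac{\\underline{x}^T \\underline{y}}{\\underline{x}^T \\underline{x} + \\alpha}$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bayesGridSketchCode"
},
"outputs": [],
"source": [
"# Sketch: Bayes' rule on a grid for a hypothetical 1-D model y = w * x + noise,\n",
"# with likelihood N(y; w x, 1) and prior N(w; 0, 1 / alpha).\n",
"import jax.numpy as jnp\n",
"import jax.scipy.stats as jstats\n",
"\n",
"x_toy = jnp.array([-1.0, 0.5, 2.0])  # hypothetical inputs\n",
"y_toy = jnp.array([-0.8, 0.7, 2.3])  # hypothetical targets\n",
"alpha_toy = 1.0\n",
"\n",
"w_grid = jnp.linspace(-5.0, 5.0, 2001)\n",
"dw = w_grid[1] - w_grid[0]\n",
"\n",
"# log p(y | x, w) for every grid value of w (summed over the datapoints)\n",
"log_lik = jstats.norm.logpdf(\n",
"    y_toy[None, :], loc=w_grid[:, None] * x_toy[None, :], scale=1.0\n",
").sum(axis=1)\n",
"log_prior = jstats.norm.logpdf(w_grid, loc=0.0, scale=alpha_toy ** -0.5)\n",
"\n",
"unnormalised = jnp.exp(log_lik + log_prior)  # likelihood * prior\n",
"evidence_toy = unnormalised.sum() * dw  # Riemann-sum estimate of p(y | x)\n",
"posterior_toy = unnormalised / evidence_toy  # Bayes' rule\n",
"\n",
"print(\"integral of likelihood * prior:\", evidence_toy)\n",
"print(\"integral of posterior:\", posterior_toy.sum() * dw)\n",
"print(\"grid posterior mean:\", (w_grid * posterior_toy).sum() * dw)\n",
"print(\"closed-form posterior mean:\", x_toy @ y_toy / (x_toy @ x_toy + alpha_toy))"
]
},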
{
"cell_type": "markdown",
"metadata": {
"id": "brWSkmbarw0j"
},
"source": [
"**Math Task:** <font color='orange'>`Intermediate`</font> (Optional)\n",
"\n",
"Compute the posterior distribution $p(\\color{blue}{\\underline{w}} | \\underline{y}, X) = \\frac{p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}})}{p(\\underline{y} | X)}$ for our linear regression model\n",
"\n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) = \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I) = \\frac{1}{ (2 \\pi)^{N/2} } \\exp(\\frac{- (y - X \\cdot \\color{blue}{\\underline{w}})^T (y - X \\cdot \\color{blue}{\\underline{w}})}{2})$ \n",
"\n",
"$p(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I) = \\frac{\\alpha^{D/2}}{ (2 \\pi)^{D/2} } \\exp(\\frac{-\\alpha \\color{blue}{\\underline{w}}^T \\color{blue}{\\underline{w}}}{2})$\n",
"\n",
"\\\n",
"* Hint 1: The product of Gaussian distributions gives us a Gaussian scaled by a constant $c$. \n",
"<center>\n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I) \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I) = c \\cdot \\mathcal{N}(\\color{blue}{\\underline{w}}; \\mu, \\Sigma)$\n",
"</center>\n",
"\n",
"$c$ is $p(\\underline{y} | X)$, which is independent of $\\color{blue}{\\underline{w}}$, so we don't need to compute it here. Your job is to find $\\mu$ and $\\Sigma$.\n",
"\n",
"\\\n",
"* Hint 2: Notice that $\\color{blue}{\\underline{w}}$ only appears in the exponents of our distributions. This means that we can ignore the terms outside the exponentials, because these only affect $c$. Thus you only need to consider the expression below (both exponents multiplied by 2), where $\\textrm{const}$ collects all terms that do not depend on $\\color{blue}{\\underline{w}}$\n",
"\n",
"<center>\n",
"$- (y - X \\cdot \\color{blue}{\\underline{w}})^T (y - X \\cdot \\color{blue}{\\underline{w}}) -\\alpha \\color{blue}{\\underline{w}}^T \\color{blue}{\\underline{w}} = - (\\color{blue}{\\underline{w}} - \\mu)^T \\Sigma^{-1} (\\color{blue}{\\underline{w}} - \\mu) + \\textrm{const}$\n",
"</center>\n",
"\n",
"Try to solve for $\\Sigma$ first. If you get stuck, ask a tutor!\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "FeHnJXGFuvOM"
},
"outputs": [],
"source": [
"# @title Answer to math task (Try not to run until you've given it a good try!)\n",
"from IPython.display import display, Markdown, Latex\n",
"\n",
"display(\"Solution Derivation\")\n",
"display(Latex(\"1. Keep only the quadratic terms in $\\\\underline{w}$ to solve for $\\\\Sigma$\"))\n",
"display(\n",
"    Latex(\n",
"        \"$ - \\\\underline{w}^T X^T X \\\\underline{w} - \\\\alpha \\\\underline{w}^T \\\\underline{w} = - \\\\underline{w}^T \\\\Sigma^{-1} \\\\underline{w} $\"\n",
"    )\n",
")\n",
"display(\"Thus\")\n",
"display(Latex(\"$\\\\Sigma = (X^T X + \\\\alpha I)^{-1}$\"))\n",
"display(Latex(\"2. Keep only the first-order terms in $\\\\underline{w}$ to solve for $\\\\mu$\"))\n",
"display(Latex(\"$2 \\\\underline{w}^T X^T \\\\underline{y} = 2 \\\\underline{w}^T \\\\Sigma^{-1} \\\\mu$\"))\n",
"display(Latex(\"$\\\\mu = \\\\Sigma X^T \\\\underline{y}$\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "38mdDCSOFfpn"
},
"outputs": [],
"source": [
"#@title **Coding task** <font color='green'>`Base`</font>: implement a function that computes the posterior distribution for linear regression\n",
"\n",
"#@markdown Even if you did not solve the math task, the solution is\n",
"#@markdown $\\Sigma = (X^T X + \\alpha I)^{-1} \\quad \\mu = \\Sigma X^T y$ \n",
"\n",
"def BLR_posterior(X, y, alpha):\n",
" \"\"\"\n",
" Computes linear regression posterior parameters\n",
" \n",
" Args:\n",
" X: jnp array of shape (N,D)\n",
" y: jnp array of shape (N,)\n",
" alpha: scalar - regularisation strength (prior precision)\n",
" Returns:\n",
" mu: jnp array of shape (D,)\n",
" covariance: jnp array of shape (D,D)\n",
" \"\"\"\n",
" D = X.shape[1] # D=2 parameters, one weight and 1 bias\n",
" covariance = # Your code goes here\n",
" mu = # Your code goes here\n",
" return mu, covariance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "fzW-EHxhxR9m"
},
"outputs": [],
"source": [
"# @title Run me to test your code\n",
"\n",
"alpha = 5\n",
"mu, covariance = BLR_posterior(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, alpha\n",
")\n",
"\n",
"mu_correct = jnp.array([-1.4181819, -1.101769])\n",
"covariance_correct = jnp.array(\n",
" [[9.090909e-02, 9.852008e-10], [9.852008e-10, 9.090909e-02]]\n",
")\n",
"\n",
"assert jnp.allclose(mu_correct, mu), \"mu is not calculated correctly\"\n",
"assert jnp.allclose(\n",
" covariance_correct, covariance\n",
"), \"Covariance is not calculated correctly\"\n",
"\n",
"print(\"It seems correct. Look at the answer below to compare methods.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Jz7ujBdKxSAO"
},
"outputs": [],
"source": [
"# @title Answer (Try not to peek until you've given it a good try!)\n",
"\n",
"\n",
"def BLR_posterior(X, y, alpha):\n",
" \"\"\"\n",
" Computes linear regression posterior parameters\n",
"\n",
" Args:\n",
" X: jnp array of shape (N,D)\n",
" y: jnp array of shape (N,)\n",
" alpha: scalar - regularisation strength (prior precision)\n",
" Returns:\n",
" mu: jnp array of shape (D,)\n",
" covariance: jnp array of shape (D,D)\n",
" \"\"\"\n",
" D = X.shape[1] # D=2 parameters, one weight and 1 bias\n",
" precision = X.T @ X + alpha * jnp.eye(D)\n",
" covariance = jnp.linalg.inv(precision)\n",
" mu = solve(precision, (X.T @ y))\n",
" return mu, covariance"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UOaLJpKZBU2V"
},
"source": [
"### The log posterior is just the negative loss, up to a constant!\n",
"\n",
"Recall that the posterior is computed as $\\frac{p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}})}{p(\\underline{y} | X)}$ and the denominator is just a normalising constant independent of $\\color{blue}{\\underline{w}}$.\n",
"\n",
"Consider the log of the unnormalised posterior\n",
"\n",
"<center>\n",
"$\\log \\left( p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) \\right) = \\log \\left( \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I) \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)\\right)$\n",
"\n",
"\\\n",
"$= \\frac{- (y - X \\cdot \\color{blue}{\\underline{w}})^T (y - X \\cdot \\color{blue}{\\underline{w}})}{2} + \\frac{-\\alpha \\color{blue}{\\underline{w}}^T \\color{blue}{\\underline{w}}}{2} + C$\n",
"\n",
"\\\n",
"$= \\frac{-1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 + \\frac{-\\alpha}{2}||\\color{blue}{\\underline{w}}||_{2}^{2} + C$\n",
"\n",
"</center>\n",
"\n",
"where $C$ is a constant independent of $\\color{blue}{\\underline{w}}$. \n",
"\n",
"**We recover our regularised loss function!!** More precisely, the log posterior is the negative regularised loss plus a constant. This helps us build intuition about the posterior density of $\\color{blue}{\\underline{w}}$: a parameter setting receives a higher density the more probable it is under our prior and the better it fits the data.\n",
"\n",
"Since the log is a monotonic function, we have just derived that the weight setting that minimises our regularised loss is the same one that maximises our posterior density (this is known as the maximum a posteriori, or MAP, estimate)! "
]
},
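{
"cell_type": "markdown",
"metadata": {
"id": "gradAtMuSketchMd"
},
"source": [
"Because the log posterior equals the negative regularised loss up to a constant, the posterior mean $\\mu$ should sit exactly at the minimum of the loss. Since the posterior is Gaussian, $\\mu$ is also its mode. The sketch below checks this with `jax.grad` on made-up data. It assumes a regularised loss of the form $\\frac{1}{2}||\\underline{y} - X\\underline{w}||_{2}^2 + \\frac{\\alpha}{2}||\\underline{w}||_{2}^{2}$, as in the derivation above. `reg_loss_toy` is a hypothetical stand-in, not the practical's `regularised_linear_regression_loss`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gradAtMuSketchCode"
},
"outputs": [],
"source": [
"# Sketch: the posterior mean mu minimises the regularised loss\n",
"#   L(w) = 0.5 * ||y - X w||^2 + 0.5 * alpha * ||w||^2,\n",
"# i.e. the log posterior (= -L(w) + C) is maximised at mu.\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"key_X_toy, key_y_toy = jax.random.split(jax.random.PRNGKey(1))\n",
"N_toy, D_toy, alpha_toy = 15, 2, 3.0\n",
"X_toy = jax.random.normal(key_X_toy, (N_toy, D_toy))  # hypothetical data\n",
"y_toy = jax.random.normal(key_y_toy, (N_toy,))\n",
"\n",
"\n",
"def reg_loss_toy(w):\n",
"    # Hypothetical stand-in for the practical's regularised loss\n",
"    residual = y_toy - X_toy @ w\n",
"    return 0.5 * (residual @ residual) + 0.5 * alpha_toy * (w @ w)\n",
"\n",
"\n",
"mu_toy = jnp.linalg.solve(X_toy.T @ X_toy + alpha_toy * jnp.eye(D_toy), X_toy.T @ y_toy)\n",
"grad_at_mu = jax.grad(reg_loss_toy)(mu_toy)\n",
"print(\"gradient of the loss at mu:\", grad_at_mu)  # close to [0, 0]\n",
"\n",
"# Stepping away from mu increases the loss (i.e. lowers the log posterior)\n",
"print(reg_loss_toy(mu_toy) < reg_loss_toy(mu_toy + 0.1))\n",
"print(reg_loss_toy(mu_toy) < reg_loss_toy(mu_toy - 0.1))"
]
},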
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "hybIp3HW03kY"
},
"outputs": [],
"source": [
"# @title We can plot our log posterior distribution by plotting our (negative) loss landscape, since the two differ only by an additive constant. {run : 'auto'}\n",
"\n",
"# @markdown **Plotting task** <font color='green'>`Base`</font>: Try different values of $\\alpha$ to see how the prior and posterior change\n",
"\n",
"# @markdown Note: the concentric lines in the plot, also known as level-sets or contours, allow us to see how fast the value of a function decreases. If the contours are very close to each other, it means that the function decreases fast.\n",
"\n",
"alpha = 3 # @param {type:\"slider\", min:0.01, max:20, step:0.01}\n",
"mu, covariance = BLR_posterior(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, alpha\n",
")\n",
"\n",
"\n",
"def generate_loss_fun():\n",
" def loss_fun(w):\n",
" return regularised_linear_regression_loss(\n",
" X=affine_basis(x_data_list_outlier),\n",
" y=y_data_list_outlier,\n",
" w=w,\n",
" vmap_model=vmap_linear,\n",
" alpha=alpha,\n",
" )\n",
"\n",
" return loss_fun\n",
"\n",
"\n",
"loss_fun = generate_loss_fun()\n",
"\n",
"x0_grid, x1_grid, loss_grid = generate_loss_grid(\n",
" loss_fun=loss_fun, grid_size=200, lim0=[-8, 8], lim1=[-8, 8]\n",
")\n",
"\n",
"\n",
"plt.figure(dpi=120)\n",
"plt.pcolormesh(\n",
" x0_grid, x1_grid, -loss_grid, vmin=-loss_grid.min(), vmax=-70, cmap=\"viridis\"\n",
") #\n",
"plt.ylabel(\"w\")\n",
"plt.xlabel(\"b\")\n",
"plt.grid(alpha=0.3)\n",
"cbar = plt.colorbar()\n",
"cbar.set_label(\"Regularised Loss\", rotation=270, labelpad=20)\n",
"\n",
"ax = plt.gca()\n",
"plot_log_gaussian_ellipse(\n",
" ax=ax,\n",
" mean=mu,\n",
" cov=covariance,\n",
" color=\"r\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"Analytical posterior contours\",\n",
" MAP_size=10,\n",
" std_levels=[1, 2, 4, 6],\n",
")\n",
"plot_log_gaussian_ellipse(\n",
" ax=ax,\n",
" mean=jnp.array([0, 0.0]),\n",
" cov=(alpha**-1) * jnp.eye(2),\n",
" color=\"w\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"Analytical prior contours\",\n",
" MAP_size=10,\n",
" std_levels=[1, 2, 4, 6],\n",
")\n",
"plt.title(\"The loss landscape matches the log posterior landscape\")\n",
"plt.xlim([-8, 8])\n",
"plt.ylim([-8, 8])\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0YWFJuSF1ILp"
},
"source": [
"In the above plot, the Bayesian posterior computed with our method (red lines) corresponds exactly with the loss regions.\n",
"\n",
"Thus, Bayesian inference is just trying to find all of the parameter settings that have low loss, i.e. that fit both our regulariser and data well.\n",
"\n",
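"Notice that the posterior contours are always tighter than the prior contours (the posterior covariance $(X^T X + \\alpha I)^{-1}$ is never larger than the prior covariance $\\alpha^{-1} I$), and in this plot they lie inside them. One interpretation of the posterior is that it is the region of the prior with parameters that fit the data well. The tighter the prior, the tighter the posterior will be.\n",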
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uoIRolsuSNMI"
},
"source": [
"**Group task**: <font color='green'>`Base`</font> - Discuss with a neighbour or tutor:\n",
"\n",
"* How does Bayesian learning relate to traditional loss minimisation learning?\n",
"\n",
"* How does the Bayesian posterior relate to the loss landscape?\n",
"\n",
"* Can Bayesian learning be performed for any model for which traditional loss minimisation is used?\n",
"\n",
"* When the posterior is wider, what does it say about our knowledge of the parameters $\\color{blue}{\\underline{w}}$? What about when it is narrower?\n",
"\n",
"* For larger values of $\\alpha$ the posterior becomes narrower. What implications do you think this will have for our predictions?\n",
"\n",
"* Do you think that observing more data would make the posterior wider or narrower?\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EMe3v8Am63RF"
},
"source": [
"## Making predictions the Bayesian way!\n",
"\n",
"Once we have our posterior distribution over $\\color{blue}{\\underline{w}}$, we are ready to make predictions.\n",
"\n",
"Instead of taking the optimum of the loss and just making predictions with that parameter setting, the Bayesian approach is to make predictions with all parameter settings and weight them by their density under the posterior. \n",
"\n",
"In other words, we go from a probability distribution over weights to a probability distribution over outputs.\n",
"\n",
"In our particular case, we want to push the weight distribution \n",
"\n",
"<center>\n",
"$\\mathcal{N}(\\color{blue}{\\underline{w}}; \\mu, \\Sigma)$\n",
"</center>\n",
"\n",
" through our function \n",
" <center>\n",
" $f(X, \\color{blue}{\\underline{w}}) = X \\color{blue}{\\underline{w}}$\n",
" </center>\n",
"\n",
"\n",
"A nice property of Gaussians is that linear transformations of Gaussian random variables are also Gaussian with mean $X \\mu $ and covariance $X \\Sigma X^T$.\n",
"\n",
"Thus $X\\cdot \\color{blue}{\\underline{w}} \\sim \\mathcal{N}(X \\cdot \\mu, X \\Sigma X^T)$\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7zrDX6aHGY3m"
},
"outputs": [],
"source": [
"# @title Exact predictive posterior for Bayesian Linear regression\n",
"\n",
"# @markdown Here is the code to implement this.\n",
"\n",
"# @markdown Instead of returning the full covariance, we will return the square root of its diagonal, $\\sqrt{\\textrm{diag}(X \\Sigma X^T)}$. This is the predictive standard deviation. The standard deviation is convenient for plotting because it has the same units as the data.\n",
"\n",
"\n",
"def BLR_predictions(X, mu, covariance):\n",
" \"\"\"\n",
" Computes linear regression posterior predictive distribution.\n",
" Instead of returning the full covariance we will return the square root of the diagonal. This is the predictive standard deviation, which is nice for plotting.\n",
"\n",
" Args:\n",
" X: jnp array of shape (N,D)\n",
" mu: jnp array of shape (D,)\n",
" covariance: jnp array of shape (D,D)\n",
" Returns:\n",
" predictive_mean: jnp array of shape (N,)\n",
" predictive_std: jnp array of shape (N,)\n",
" \"\"\"\n",
" predictive_mean = X @ mu\n",
" predictive_std = jnp.diag(X @ covariance @ X.T) ** 0.5\n",
" return predictive_mean, predictive_std"
]
},
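{
"cell_type": "markdown",
"metadata": {
"id": "diagStdSketchMd"
},
"source": [
"`BLR_predictions` builds the full $N \\times N$ matrix $X \\Sigma X^T$ and then keeps only its diagonal. That is fine for a few hundred test points, but the memory cost grows quadratically with $N$. The sketch below, using hypothetical random inputs and covariance, shows one possible alternative. It computes each diagonal entry $\\underline{x}_n^T \\Sigma \\underline{x}_n$ directly with `jnp.einsum` and should give the same predictive standard deviation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "diagStdSketchCode"
},
"outputs": [],
"source": [
"# Sketch: predictive std from diag(X Sigma X^T) without forming the N x N matrix.\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"key_X_toy, key_A_toy = jax.random.split(jax.random.PRNGKey(2))\n",
"X_toy = jax.random.normal(key_X_toy, (201, 2))  # hypothetical basis-expanded test inputs\n",
"A_toy = jax.random.normal(key_A_toy, (2, 2))\n",
"covariance_toy = A_toy @ A_toy.T + 0.1 * jnp.eye(2)  # hypothetical positive-definite covariance\n",
"\n",
"# Original approach: materialises a (201, 201) matrix, then takes its diagonal\n",
"std_full = jnp.diag(X_toy @ covariance_toy @ X_toy.T) ** 0.5\n",
"\n",
"# diag(X Sigma X^T)_n = sum_{d,e} X[n, d] Sigma[d, e] X[n, e]\n",
"std_diag = jnp.sqrt(jnp.einsum(\"nd,de,ne->n\", X_toy, covariance_toy, X_toy))\n",
"\n",
"print(std_diag.shape)  # (201,)\n",
"print(jnp.max(jnp.abs(std_full - std_diag)))  # close to 0"
]
},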
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "NuzGQt1v7kTK"
},
"outputs": [],
"source": [
"# @title **Plotting task** <font color='green'>`Base`</font>: Let's look at the posterior predictive distribution for different values of $\\alpha$ {run : 'auto'}\n",
"\n",
"alpha = 3.0 # @param {type:\"slider\", min:0.01, max:20, step:0.01}\n",
"mu, covariance = BLR_posterior(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, alpha\n",
")\n",
"\n",
"\n",
"plt.figure(dpi=100)\n",
"\n",
"xlim = [jnp.min(x_data_list_outlier) - 1, jnp.max(x_data_list_outlier) + 1]\n",
"ylim = [jnp.min(y_data_list_outlier) - 1, jnp.max(y_data_list_outlier) + 1]\n",
"\n",
"x_test = jnp.linspace(xlim[0], xlim[1], 201)\n",
"X_test = affine_basis(x_test)\n",
"\n",
"predictive_mean, predictive_std = BLR_predictions(X_test, mu, covariance)\n",
"\n",
"ax = plt.gca()\n",
"errorfill(\n",
" x_test,\n",
" predictive_mean,\n",
" predictive_std,\n",
" color=\"red\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax,\n",
" lw=2,\n",
" linestyle=\"-\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=None,\n",
" markevery=None,\n",
")\n",
"\n",
"\n",
"ax.plot(x_data_list_outlier, y_data_list_outlier, \"ob\")\n",
"ax.set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
"ax.grid(alpha=0.3)\n",
"ax.legend()\n",
"plt.title(\"predictive mean + 1 standard deviation errorbars on either side\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NLLP3H2T_DjK"
},
"source": [
"### Making predictions by sampling\n",
"\n",
"The uncertainty in our posterior distribution over $\\color{blue}{\\underline{w}}$ translates to uncertainty (errorbars) in our predictions.\n",
"\n",
"\n",
"\n",
"We can get a more concrete intuition for how this happens by drawing $K$ samples from our posterior distribution over parameters \n",
"\n",
"<center>\n",
"$ \\color{blue}{\\underline{w}}_{k} \\sim \\mathcal{N}(\\mu, \\Sigma)$\n",
"</center>\n",
"\n",
"and pushing them through our function $f$ to obtain $K$ different regression lines \n",
"\n",
"<center>\n",
"$X \\color{blue}{\\underline{w}}_{1}, X \\color{blue}{\\underline{w}}_{2}, ..., X\\color{blue}{\\underline{w}}_{K}$\n",
"</center> \n",
"\n",
"We can then estimate the predictive mean $\\mu_f$ and predictive standard deviation $\\sigma_f$ from these samples:\n",
"\n",
"<center>\n",
"$\\mu_f = X \\mu = \\mathbb{E}[X \\color{blue}{\\underline{w}}] \\approx \\frac{1}{K}\\sum_{k=1}^K X \\color{blue}{\\underline{w}}_{k}$\n",
"</center>\n",
"\n",
"\n",
"<center>\n",
"$\\sigma_f = \\sqrt{\\textrm{diag}(X \\Sigma X^T)} = \\sqrt{\\mathbb{E}[(X \\color{blue}{\\underline{w}} - \\mu_f)^2]} \\approx \\sqrt{\\frac{1}{K}\\sum_{k=1}^K (X \\color{blue}{\\underline{w}}_{k} - \\mu_f)^2}$\n",
"</center>\n",
"\n",
"\n",
"This procedure of estimating expectations as averages over samples is known as **Monte Carlo estimation**. Even if the exact expression for the predictive posterior were not available, we could always rely on Monte Carlo to make predictions (and we will do that in the next sections).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aFFwUkM9GY6G"
},
"outputs": [],
"source": [
"key = random.PRNGKey(0)\n",
"Nsamples = 10\n",
"\n",
"parameter_samples = jax.random.multivariate_normal(\n",
" key, mu, covariance, shape=(Nsamples,)\n",
")\n",
"sample_preds = vmap_linear(X_test, parameter_samples.T)\n",
"MC_mean = sample_preds.mean(axis=1)\n",
"MC_std = sample_preds.std(axis=1)"
]
},
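{
"cell_type": "markdown",
"metadata": {
"id": "mcErrorSketchMd"
},
"source": [
"How many samples are enough? The sketch below uses a hypothetical 2-D Gaussian posterior (`mu_toy`, `covariance_toy`) and a single hypothetical test input `x_star_toy`, named so as not to overwrite the practical's variables. It compares Monte Carlo estimates of the predictive mean and standard deviation against the exact values $\\underline{x}_{*}^T \\mu$ and $\\sqrt{\\underline{x}_{*}^T \\Sigma \\underline{x}_{*}}$ for increasing $K$. The errors should shrink roughly like $1/\\sqrt{K}$, although any single run is noisy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mcErrorSketchCode"
},
"outputs": [],
"source": [
"# Sketch: Monte Carlo error of the predictive mean / std for a growing number of samples K.\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"mu_toy = jnp.array([0.5, -1.0])  # hypothetical posterior mean\n",
"covariance_toy = jnp.array([[0.3, 0.1], [0.1, 0.2]])  # hypothetical posterior covariance\n",
"x_star_toy = jnp.array([1.0, 2.0])  # hypothetical basis-expanded test input\n",
"\n",
"exact_mean = x_star_toy @ mu_toy  # = -1.5\n",
"exact_std = jnp.sqrt(x_star_toy @ covariance_toy @ x_star_toy)  # = sqrt(1.5)\n",
"\n",
"key_toy = jax.random.PRNGKey(0)\n",
"for n_samples_toy in [10, 100, 1000, 10000]:\n",
"    key_toy, subkey_toy = jax.random.split(key_toy)\n",
"    w_samples_toy = jax.random.multivariate_normal(\n",
"        subkey_toy, mu_toy, covariance_toy, shape=(n_samples_toy,)\n",
"    )\n",
"    f_samples_toy = w_samples_toy @ x_star_toy  # one prediction per weight sample\n",
"    mean_err = float(abs(f_samples_toy.mean() - exact_mean))\n",
"    std_err = float(abs(f_samples_toy.std() - exact_std))\n",
"    print(f\"K={n_samples_toy:6d}  |mean error|={mean_err:.4f}  |std error|={std_err:.4f}\")"
]
},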
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "odW4YacV61dC"
},
"outputs": [],
"source": [
"# @title Visualising the Monte Carlo estimate of the predictive posterior\n",
"\n",
"fig, ax = plt.subplots(1, 3, dpi=200, figsize=(12, 4))\n",
"\n",
"\n",
"loss_landscape = ax[0].pcolormesh(\n",
" x0_grid,\n",
" x1_grid,\n",
" -loss_grid,\n",
" vmin=-loss_grid.min() * 5,\n",
" vmax=-loss_grid.min(),\n",
" cmap=\"viridis\",\n",
") #\n",
"ax[0].set_ylabel(\"w\")\n",
"ax[0].set_xlabel(\"b\")\n",
"ax[0].grid(alpha=0.3)\n",
"ax[0].set_title(\"Loss landscape, weight posterior and samples\")\n",
"\n",
"plot_log_gaussian_ellipse(\n",
" ax=ax[0],\n",
" mean=mu,\n",
" cov=covariance,\n",
" color=\"r\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"posterior contours\",\n",
" MAP_size=10,\n",
" std_levels=[1, 2, 4, 6],\n",
")\n",
"\n",
"ax[0].scatter(\n",
" parameter_samples[:, 0], parameter_samples[:, 1], 8, color=\"green\", label=\"samples\"\n",
")\n",
"cbar = fig.colorbar(loss_landscape, ax=ax[0])\n",
"cbar.set_label(\"Regularised Loss\", rotation=270, labelpad=20)\n",
"ax[0].legend()\n",
"\n",
"\n",
"xlim = [jnp.min(x_data_list_outlier) - 1, jnp.max(x_data_list_outlier) + 1]\n",
"ylim = [jnp.min(y_data_list_outlier) - 1, jnp.max(y_data_list_outlier) + 1]\n",
"\n",
"x_test = jnp.linspace(xlim[0], xlim[1], 201)\n",
"X_test = affine_basis(x_test)\n",
"\n",
"errorfill(\n",
" x_test,\n",
" predictive_mean,\n",
" predictive_std,\n",
" color=\"red\",\n",
" alpha_fill=0.1,\n",
" line_alpha=1,\n",
" ax=ax[1],\n",
" lw=1,\n",
" linestyle=\"-\",\n",
" fill_linewidths=2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"exact predictive\",\n",
" markevery=None,\n",
")\n",
"\n",
"\n",
"ax[1].plot(x_data_list_outlier, y_data_list_outlier, \"ob\")\n",
"ax[1].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
"ax[1].grid(alpha=0.3)\n",
"ax[1].legend()\n",
"ax[1].set_title(f\"exact predictive posterior\")\n",
"\n",
"ax[2].plot(x_test, sample_preds, \"-\", color=\"green\", lw=0.3, alpha=0.9)\n",
"errorfill(\n",
" x_test,\n",
" MC_mean,\n",
" MC_std,\n",
" color=\"green\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax[2],\n",
" lw=1,\n",
" linestyle=\"--\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"sample estimate\",\n",
" markevery=None,\n",
")\n",
"\n",
"\n",
"ax[2].plot(x_data_list_outlier, y_data_list_outlier, \"ob\")\n",
"ax[2].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
"ax[2].grid(alpha=0.3)\n",
"ax[2].legend()\n",
"ax[2].set_title(f\" MC estimate from {Nsamples} samples\")\n",
"\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z7vPRsmDHulk"
},
"source": [
"**Group task** <font color='green'>`Base`</font> - Discuss with a neighbour:\n",
"\n",
"* Is our sample based estimate of the predictive distribution good? Do we have enough samples?\n",
"* Why are the error bars of the predictive posterior narrower near the middle of the data?\n",
"* How would the errorbars change if we added more data (observations)? What about if we decreased the number of observations? What about if we increased the regularisation strength?\n",
"* What observation (as in [x, y] value pair) would we have to make to maximally decrease the errorbar size?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hFb8e19EIx4T"
},
"source": [
"## Takeaways\n",
"\n",
"* Bayesian learning makes our prior beliefs about what our function will look like explicit, through the choice of a prior distribution. Bayesian learning uses Bayes' rule to update this prior distribution with our data. Specifically, out of all the parameter settings allowed by our prior, we only keep the ones that have a high likelihood (that fit the data well).\n",
"\n",
"* The log posterior density is just the negative regularised loss plus a constant; equivalently, the posterior is the exponential of the negative loss, normalised so that it integrates to 1. Thus, weight settings that have large posterior density are those that have low loss.\n",
"\n",
"* We make predictions by combining the predictions from every possible parameter setting, weighted by their respective posterior densities. Disagreement among the regression lines from these parameter settings induces uncertainty in the predictions. When exact computations are not tractable, we can estimate the predictive distribution with Monte Carlo.\n",
"\n",
"## End of Section\n",
"\n",
"The rest of the section contains more advanced, optional content. You can skip to [Section 2](#S2) if you want."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j2x7Nw7SBkdJ"
},
"source": [
"## With a more sophisticated basis, linear regression can be very powerful (Optional) - <font color='orange'>`Intermediate`</font>\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "BH3D-TZSBoE9"
},
"outputs": [],
"source": [
"# @markdown Let's introduce a dataset that requires a non-linear model.\n",
"x_data_list_nonlinear = jnp.array([-2, -1, 0, 1, 1.1, 3.5, 4, 5])\n",
"y_data_list_nonlinear = jnp.array([3.2, 3.3, 3.2, 1.3, 1.6, 1.1, 2.0, 3.5]) * 5\n",
"\n",
"plot_basic_data(x_data_list_nonlinear, y_data_list_nonlinear)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rwTi88eycFn6"
},
"source": [
"### A non-linear basis\n",
"\n",
"\n",
"So far we have used the affine basis $x \\to [1, x]$. This allows us to generate lines with a learnt slope and offset. \n",
"\n",
"Now consider the basis $x \\to [\\cos(s_{1} x + u_{1}), \\cos(s_{2} x + u_{2}), \\,..., \\cos(s_{D} x + u_{D})]$ where $\\,\\,s_{i} \\sim \\mathcal{N}(0, \\sigma_s^{2})$ and $\\,\\,u_{i} \\sim U(-\\pi, \\pi)$.\n",
"\n",
"This is known as a \"random Fourier features\" basis. Given a large enough number of basis functions $D$, it can approximate a very wide class of functions. See the paper [Random Features for Large-Scale Kernel Machines](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) for more details on this basis.\n"
]
},
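{
"cell_type": "markdown",
"metadata": {
"id": "rffKernelSketchMd"
},
"source": [
"Why does this basis work so well? One way to see it, following the paper linked above, is that inner products of random cosine features approximate a squared-exponential (RBF) kernel. For $s \\sim \\mathcal{N}(0, \\sigma_s^2)$ and $u \\sim U(-\\pi, \\pi)$, we have $\\mathbb{E}[2\\cos(s x + u)\\cos(s x' + u)] = \\exp(-\\sigma_s^2 (x - x')^2 / 2)$. The sketch below checks this numerically with a hypothetical helper `random_cosine_features_toy`, which uses the $\\sqrt{2/D}$ scaling from that paper. The practical's `generate_cosine_basis` scales by $1/\\sqrt{D}$ instead, which halves the kernel. That constant factor can be absorbed into the prior precision $\\alpha$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rffKernelSketchCode"
},
"outputs": [],
"source": [
"# Sketch: random cosine features approximate an RBF kernel. With D = num_features,\n",
"#   (2 / D) * sum_i cos(s_i x + u_i) cos(s_i x' + u_i) ~= exp(-sigma_s^2 (x - x')^2 / 2)\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def random_cosine_features_toy(key, x, num_features, sigma_s):\n",
"    \"\"\"Hypothetical helper: maps x of shape (N,) to features of shape (N, num_features).\"\"\"\n",
"    key_s, key_u = jax.random.split(key)\n",
"    s = sigma_s * jax.random.normal(key_s, (num_features,))\n",
"    u = jax.random.uniform(key_u, (num_features,), minval=-jnp.pi, maxval=jnp.pi)\n",
"    return jnp.sqrt(2.0 / num_features) * jnp.cos(x[:, None] * s[None, :] + u[None, :])\n",
"\n",
"\n",
"sigma_s_toy = 0.75\n",
"x_toy = jnp.linspace(-3.0, 3.0, 7)\n",
"# The same key (hence the same s and u) must be used for every input\n",
"phi_toy = random_cosine_features_toy(jax.random.PRNGKey(0), x_toy, 20000, sigma_s_toy)\n",
"\n",
"approx_kernel = phi_toy @ phi_toy.T\n",
"exact_kernel = jnp.exp(-0.5 * sigma_s_toy**2 * (x_toy[:, None] - x_toy[None, :]) ** 2)\n",
"print(jnp.max(jnp.abs(approx_kernel - exact_kernel)))  # small; shrinks as num_features grows"
]
},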
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xnwN2Lxm1mFP"
},
"outputs": [],
"source": [
"def generate_cosine_basis(key, N_inputs, N_elements, s_std=1):\n",
"\n",
" key, _ = random.split(key)\n",
" S = random.normal(key, shape=(N_inputs, N_elements)) * s_std\n",
" key, _ = random.split(key)\n",
" u = (random.uniform(key, shape=(N_elements,)) - 0.5) * 2 * jnp.pi\n",
"\n",
" def basis_expand(x):\n",
" if x.ndim == 1:\n",
" x = x.copy()[:, None]\n",
" return jnp.cos(x @ S + u) / np.sqrt(N_elements)\n",
"\n",
" return basis_expand"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "llz38Em7e5X6"
},
"outputs": [],
"source": [
"alpha = 0.001\n",
"std_s = 0.75"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "l2aZzsK11mLv"
},
"outputs": [],
"source": [
"# @title Run regression using the code from this section, but with the random Fourier basis\n",
"\n",
"# @markdown As before, the thick red line is the predictive mean and the transparent band represents one standard deviation on either side. The thin lines represent posterior samples.\n",
"\n",
"N_inputs = 1\n",
"N_elements = 500\n",
"\n",
"key = random.PRNGKey(42)\n",
"\n",
"\n",
"def gen_plots(params):\n",
"\n",
" xlim = [-4, 7]\n",
"\n",
" alpha = params[1]\n",
" fourier_basis = generate_cosine_basis(key, N_inputs, N_elements, s_std=params[0])\n",
"\n",
" X = fourier_basis(x_data_list_nonlinear)\n",
" mu, covariance = BLR_posterior(X, y_data_list_nonlinear, alpha)\n",
"\n",
" x_pred = jnp.linspace(xlim[0], xlim[1], 101)\n",
" X_pred = fourier_basis(x_pred)\n",
"\n",
" parameter_samples = jax.random.multivariate_normal(\n",
" key, mu, covariance, shape=(Nsamples,)\n",
" )\n",
" sample_preds = vmap_linear(X_pred, parameter_samples.T)\n",
"\n",
" predictive_mean, predictive_std = BLR_predictions(X_pred, mu, covariance)\n",
"\n",
" #################\n",
"\n",
" fig, ax = plt.subplots()\n",
" ylim = [-15, 30]\n",
" ax.plot(x_pred, sample_preds, \"-\", color=\"red\", lw=0.3)\n",
"\n",
" errorfill(\n",
" x_pred,\n",
" predictive_mean,\n",
" predictive_std,\n",
" color=\"red\",\n",
" alpha_fill=0.15,\n",
" line_alpha=1,\n",
" ax=ax,\n",
" lw=2,\n",
" linestyle=\"-\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=None,\n",
" markevery=None,\n",
" )\n",
"\n",
" ax.plot(x_data_list_nonlinear, y_data_list_nonlinear, \"ob\")\n",
" ax.set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
" ax.grid(alpha=0.3)\n",
"\n",
"\n",
"gen_plots([std_s, alpha])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k9pYt0CbHlu6"
},
"source": [
"**Group task** (Optional) <font color='orange'>`Intermediate`</font>: - Discuss with a neighbour:\n",
"\n",
"* Why do the errorbars grow so large in the [1, 3] range and outside the x = [-1, 5] range?\n",
"* What effect does $\\alpha$ have? What about the standard deviation $\\sigma_s$ of the random frequencies? "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uAbUupw9b_v1"
},
"source": [
"## The model evidence for hyperparameter selection (Optional) - <font color='red'>`Advanced`</font>\n",
"\n",
"\n",
"Recall Bayes rule $\\frac{p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}})}{p(\\underline{y} | X)}$. In the beginning of the section, we discounted the normalising constant $p(\\underline{y} | X) $ as not important due to it being constant in $\\color{blue}{\\underline{w}}$.\n",
"\n",
"However, this term, which is known as the **Model evidence** or **Marginal likelihood** (both expressions mean exactly the same thing), can be very useful. To see this consider the expanded expression:\n",
"\n",
"<center>\n",
"$p(\\underline{y} | X) = \\int p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) d \\color{blue}{\\underline{w}}$\n",
"</center>\n",
"\n",
"it is the integral of the likelihood $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$, which measures data fit, against the prior density $p(\\color{blue}{\\underline{w}})$. This expression is very similar to that of the predictive posterior, except that the distribution we integrate against is the prior and not the posterior. \n",
"\n",
"**The marginal likelihood tells us how well we fit the data with the prior. It can be used as a learning objective for prior hyperparameters**, such as the random Fourier frequency variance $\\sigma^2_s$ or the prior precision $\\alpha$. When we use the marginal likelihood, our weights $\\color{blue}{\\underline{w}}$ are integrated out, so the only free parameters are those of our prior. Since our prior will usually have a small number of parameters, we can usually fit them using all of our data, with much less risk of overfitting than when fitting the weights directly. \n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Myv6WRrPNP8"
},
"source": [
"In the Gaussian linear model, we compute the marginal likelihood as \n",
"\n",
"<center>\n",
"$p(\\underline{y} | X) = \\int \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I) \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I) d \\color{blue}{\\underline{w}}$\n",
"</center>\n",
"\n",
"You can try to solve this integral as a take-home exercise ([this](https://www.utstat.utoronto.ca/~radford/sta414.S11/week4a.pdf) is a nice resource to help you). To save time, we provide the solution here. In practice, we work with logs for numerical stability.\n",
"\n",
"$\\log p(\\underline{y} | X) = -\\frac{1}{2}||\\underline{y} - X\\cdot \\mu||_{2}^2 - \\frac{\\alpha}{2}||\\mu||_{2}^{2} + \\frac{1}{2} \\textrm{logdet}(\\Sigma) + \\frac{D}{2} \\log \\alpha - \\frac{N}{2} \\log(2 \\pi)$\n",
"\n",
"\n",
"Here $\\mu$ and $\\Sigma$ are the posterior mean and covariance computed by `BLR_posterior`. The next cell contains an implementation."
]
},
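{
"cell_type": "markdown",
"metadata": {
"id": "evidenceCheckSketchMd"
},
"source": [
"One way to sanity-check this expression is to note that integrating $\\underline{w} \\sim \\mathcal{N}(0, \\alpha^{-1} I)$ out of $\\underline{y} = X \\underline{w} + \\epsilon$, with $\\epsilon \\sim \\mathcal{N}(0, I)$, gives $\\underline{y} \\sim \\mathcal{N}(0, I + \\alpha^{-1} X X^T)$. The log density of this Gaussian can be evaluated directly with `jax.scipy.stats.multivariate_normal.logpdf`. The sketch below compares the two on made-up data (`X_toy`, `y_toy` and `alpha_toy` are hypothetical); they should agree up to floating-point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "evidenceCheckSketchCode"
},
"outputs": [],
"source": [
"# Sketch: closed-form log evidence vs. the direct Gaussian marginal likelihood.\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy.stats as jstats\n",
"\n",
"key_X_toy, key_y_toy = jax.random.split(jax.random.PRNGKey(3))\n",
"N_toy, D_toy, alpha_toy = 8, 3, 0.5\n",
"X_toy = jax.random.normal(key_X_toy, (N_toy, D_toy))  # hypothetical basis-expanded inputs\n",
"y_toy = jax.random.normal(key_y_toy, (N_toy,))  # hypothetical targets\n",
"\n",
"precision_toy = X_toy.T @ X_toy + alpha_toy * jnp.eye(D_toy)\n",
"Sigma_toy = jnp.linalg.inv(precision_toy)\n",
"mu_toy = Sigma_toy @ X_toy.T @ y_toy\n",
"_, logdet_Sigma_toy = jnp.linalg.slogdet(Sigma_toy)\n",
"\n",
"closed_form = (\n",
"    -0.5 * jnp.sum((y_toy - X_toy @ mu_toy) ** 2)\n",
"    - 0.5 * alpha_toy * jnp.sum(mu_toy**2)\n",
"    + 0.5 * logdet_Sigma_toy\n",
"    + 0.5 * D_toy * jnp.log(alpha_toy)\n",
"    - 0.5 * N_toy * jnp.log(2 * jnp.pi)\n",
")\n",
"direct = jstats.multivariate_normal.logpdf(\n",
"    y_toy, jnp.zeros(N_toy), jnp.eye(N_toy) + X_toy @ X_toy.T / alpha_toy\n",
")\n",
"print(closed_form, direct)  # should agree up to floating-point error"
]
},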
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ez6pxY1fvsrt"
},
"outputs": [],
"source": [
"def linear_evidence(x_data_list, y_data_list, basis, alpha):\n",
" X = basis(x_data_list)\n",
" mu, covariance = BLR_posterior(X, y_data_list, alpha)\n",
" loss = regularised_linear_regression_loss(\n",
" X=X, y=y_data_list, w=mu, vmap_model=vmap_linear, alpha=alpha\n",
" )\n",
" a, log_det = jnp.linalg.slogdet(covariance)\n",
" return (\n",
" -loss\n",
" + 0.5 * (len(mu) * jnp.log(alpha) + log_det)\n",
" - 0.5 * len(x_data_list) * jnp.log(2 * jnp.pi)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7k_HJq1GLcR0"
},
"outputs": [],
"source": [
"# @markdown Let's compute the evidence of our affine model and random Fourier model on the new non-linear data\n",
"\n",
"alpha = 0.01\n",
"s_std = 0.75\n",
"fourier_basis = generate_cosine_basis(key, N_inputs, N_elements, s_std)\n",
"fourier_evidence = linear_evidence(\n",
" x_data_list_nonlinear, y_data_list_nonlinear, fourier_basis, alpha\n",
")\n",
"affine_evidence = linear_evidence(\n",
" x_data_list_nonlinear, y_data_list_nonlinear, affine_basis, alpha\n",
")\n",
"print(\"fourier_evidence\", fourier_evidence)\n",
"print(\"affine_evidence\", affine_evidence)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LzfroHOt2iX3"
},
"source": [
"**Some questions** <font color='red'>`Advanced`</font>\n",
"* Why is the Fourier model's evidence larger than the affine model's evidence?\n",
"* Is the same true for the other dataset? (x_data_list_outlier, y_data_list_outlier) -- you might need to tune alpha manually for each model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IxkKK5791DQn"
},
"outputs": [],
"source": [
"# @title **Plotting Task** (Optional) <font color='red'>`Advanced`</font>: Optimise the Fourier basis standard deviation $\\sigma_s$ and the regularisation strength $\\alpha$ with the linear evidence objective\n",
"\n",
"\n",
"def gen_objective(x_data_list, y_data_list, key, N_elements):\n",
" def objective(params, key):\n",
" basis = generate_cosine_basis(key, 1, N_elements, s_std=params[0])\n",
" return linear_evidence(x_data_list, y_data_list, basis, params[1])\n",
"\n",
" return objective\n",
"\n",
"\n",
"N_elements = 300\n",
"objective = gen_objective(x_data_list_nonlinear, y_data_list_nonlinear, key, N_elements)\n",
"\n",
"optimised_regularised_classification_params = optimise(\n",
" objective,\n",
" params=jnp.array([2, 0.001]),\n",
" plotting_func=gen_plots,\n",
" LR=2e-3,\n",
" MAX_STEPS=1300,\n",
" LOG_EVERY=100,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K56NMfdP3ReF"
},
"source": [
"**Some more questions** <font color='red'>`Advanced`</font>\n",
"* Is there any overfitting happening? \n",
"* Do the errorbars become bigger or smaller? Why?\n",
"* When is it safe to use the model evidence as a training objective for our hyperparameters?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cOGuGWtLmP7n"
},
"source": [
"# Section 2 - Logistic Regression -- the need for approximate inference <a name=\"S2\"></a>\n",
"\n",
"Linear regression was nice because the posterior distribution over the parameters and the predictive posterior could be written down in closed form.\n",
"We now leave the world of exact solutions that can be computed with pen and paper. This is a short section meant as motivation for the rest of the prac. Spend no more than 10 minutes on it.\n",
"\n",
"We will start by considering logistic regression. We are going to deal with a classification task where our targets are either 0 or 1. Data that takes these values can be described by a [Bernoulli distribution](https://en.wikipedia.org/wiki/Bernoulli_distribution):\n",
"\n",
"<center>\n",
"$y_{n} \\sim \\text{Bern}(\\rho)$\n",
"</center>\n",
"\n",
"where $\\rho \\in [0,1]$ is the probability of our observations taking a value of 1. Our model will predict this probability for each of the inputs: $\\rho = f(x, \\color{blue}{\\underline{w}})$.\n",
"\n",
"We will use the same affine linear model as before, but we need to restrict the output to the $[0,1]$ range. For this, we use the sigmoid function $\\phi(f) = \\frac{1}{1 + \\exp(-f)}$.\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vMLBmsEjWz8B"
},
"source": [
"The final piece of the puzzle is the standard assumption that our target labels $\\underline{y} | X, \\color{blue}{\\underline{w}}$ are [independent and identically distributed](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables) (i.i.d.) given our \n",
"inputs and weights. This assumption allows us to write the likelihood as a product of the likelihoods of individual datapoints: $p(\\underline{y} | X, \\color{blue}{\\underline{w}}) = \\prod_{n=1}^N p(y_n | x_n, \\color{blue}{\\underline{w}})$. \n",
"\n",
"We were also implicitly making this assumption in the previous section when we chose an identity covariance for our Gaussian likelihood: a diagonal covariance means the targets are independent given the weights. We just did not write the likelihood out as a product, to keep the notation compact.\n",
"\n",
"\n",
"Putting it all together we have \n",
"\n",
"<center>\n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) = \\prod_{n=1}^N \\text{Bern}(y_{n}; \\phi(x_{n}\\color{blue}{\\underline{w}}))$\n",
"\n",
"and as before we keep the Gaussian prior over our parameters\n",
"\\\n",
"$p(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)$\n",
"</center>"
]
},
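{
"cell_type": "markdown",
"metadata": {},
"source": [
"The i.i.d. Bernoulli likelihood above can be checked numerically. In this illustrative sketch the logits `f` (playing the role of $x_n \\color{blue}{\\underline{w}}$) and the labels `y` are made-up values, and `bernoulli_log_likelihood` is a hypothetical helper, not part of the notebook. It uses $1 - \\phi(f) = \\phi(-f)$ together with `jax.nn.log_sigmoid` to evaluate $\\log \\text{Bern}(y; \\phi(f))$ stably. It then compares the sum of the per-datapoint log-likelihoods with the log of their product.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax.nn import log_sigmoid, sigmoid\n",
"\n",
"\n",
"def bernoulli_log_likelihood(f, y):\n",
"    # log Bern(y; sigmoid(f)) = y log sigmoid(f) + (1 - y) log(1 - sigmoid(f)),\n",
"    # where 1 - sigmoid(f) = sigmoid(-f) avoids taking the log of a tiny difference.\n",
"    return y * log_sigmoid(f) + (1 - y) * log_sigmoid(-f)\n",
"\n",
"\n",
"# Hypothetical logits f_n = x_n w and binary labels y_n.\n",
"f = jnp.array([-2.0, -0.5, 0.3, 1.5])\n",
"y = jnp.array([0.0, 0.0, 1.0, 1.0])\n",
"\n",
"# Under the i.i.d. assumption, the log of the product is the sum of the logs.\n",
"total_ll = bernoulli_log_likelihood(f, y).sum()\n",
"probs = sigmoid(f) ** y * (1 - sigmoid(f)) ** (1 - y)\n",
"naive_ll = jnp.log(jnp.prod(probs))\n",
"```\n",
"\n",
"Summing log-likelihoods, rather than multiplying probabilities, also avoids numerical underflow when there are many datapoints."
]
},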
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tj9IEgVzAmLI"
},
"outputs": [],
"source": [
"# @title Dataset for logistic regression.\n",
"# @markdown The task will be to classify points with the labels 0 or 1.\n",
"\n",
"x_data_list_classification = jnp.array([-2.7, -1.9, -1.3, 0.7, 0.8, 1.4, 2.4])\n",
"y_data_list_classification = jnp.array([0, 0, 0, 1, 1, 1, 1])\n",
"\n",
"\n",
"plot_basic_data(x_data_list_classification, y_data_list_classification)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "83JVL1LU7aJr"
},
"outputs": [],
"source": [
"# @title Define the model, same as above but with sigmoid\n",
"\n",
"\n",
"def logistic(w, x):\n",
" return sigmoid(linear(w, x))\n",
"\n",
"\n",
"vmap_logistic = jit(vmap(logistic, in_axes=(None, 0)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M1dFiIa8EE--"
},
"source": [
"## Loss landscape for classification\n",
"\n",
"\n",
"As in linear regression, let's write out the log of the unnormalised posterior; its negative will give us a loss function to minimise.\n",
"\n",
"<center>\n",
"$\\log \\left( p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) \\right) = \\log \\left( \\prod_{n=1}^N \\text{Bern}(y_n ; \\phi(x_n \\color{blue}{\\underline{w}})) \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)\\right)$\n",
"</center>\n",
"\n",
"\\\n",
"Note that the Bernoulli distribution is defined over discrete outcomes. In the previous sections, we dealt with Gaussian targets that could take any value in $\\mathbb{R}$, so we worked with probability density functions. Now our targets can only take the values $\\{0,1\\}$, so we can assign each of these outcomes a **probability mass**. The Bernoulli probability mass function is $\\text{Bern}(y ; \\rho) = \\rho^y (1-\\rho)^{(1-y)}$. Its log, $\\log \\text{Bern}(y ; \\rho) = y \\log \\rho + (1-y) \\log (1-\\rho)$, is exactly the negative of the binary cross entropy loss.\n",
"\n",
"\n",
"\\\n",
"Negating the log joint and dropping constants (the normaliser of the Gaussian prior), the total loss is\n",
"\n",
"<center>\n",
"$-\\sum_{n=1}^N \\left[ y_n \\log \\phi(x_n \\color{blue}{\\underline{w}}) + (1-y_n) \\log (1-\\phi(x_n \\color{blue}{\\underline{w}})) \\right] + \\frac{\\alpha}{2} ||\\color{blue}{\\underline{w}}||_{2}^2$\n",
"</center>\n",
"\n",
"**i.e. the regularised cross entropy loss falls right out of our Bayesian model!!**"
]
},
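{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the step of dropping constants concrete, this sketch compares two quantities at two different weight vectors. The first is the full negative log joint, including the normalising constant of the Gaussian prior. The second is the regularised cross entropy loss. The sketch is illustrative only: `neg_log_joint`, `regularised_bce` and `log_likelihood` are hypothetical helpers (not the notebook's `logistic_regression_loss`), and the affine features $[1, x]$ are built by hand. The data copy the toy classification dataset above, and $\\alpha = 0.75$ matches the plotting cell below. The gap between the two functions should be identical at both points, because it depends only on $\\alpha$ and the number of weights.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax.nn import log_sigmoid\n",
"from jax.scipy.stats import norm\n",
"\n",
"# Copies of the toy classification data, with a hand-built affine basis [1, x].\n",
"x = jnp.array([-2.7, -1.9, -1.3, 0.7, 0.8, 1.4, 2.4])\n",
"y = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0])\n",
"X = jnp.stack([jnp.ones_like(x), x], axis=1)\n",
"alpha = 0.75\n",
"\n",
"\n",
"def log_likelihood(w):\n",
"    f = X @ w\n",
"    return (y * log_sigmoid(f) + (1 - y) * log_sigmoid(-f)).sum()\n",
"\n",
"\n",
"def neg_log_joint(w):\n",
"    # Gaussian prior N(0, alpha^-1 I), including its normalising constant.\n",
"    log_prior = norm.logpdf(w, loc=0.0, scale=alpha ** -0.5).sum()\n",
"    return -(log_likelihood(w) + log_prior)\n",
"\n",
"\n",
"def regularised_bce(w):\n",
"    # Cross entropy (negative log-likelihood) plus the L2 penalty.\n",
"    return -log_likelihood(w) + 0.5 * alpha * (w ** 2).sum()\n",
"\n",
"\n",
"w_a = jnp.array([0.0, 0.0])\n",
"w_b = jnp.array([1.0, 3.0])\n",
"gap_a = neg_log_joint(w_a) - regularised_bce(w_a)\n",
"gap_b = neg_log_joint(w_b) - regularised_bce(w_b)\n",
"```\n",
"\n",
"Because the gap does not depend on the weights, the two objectives share the same minimiser and the same gradients."
]
},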
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zlLfPAlAEM_a"
},
"outputs": [],
"source": [
"# @title Implementation of regularised cross entropy loss\n",
"# @markdown Notice that we do not apply the sigmoid ourselves, since it is built into optax's `sigmoid_binary_cross_entropy`, which takes logits as input.\n",
"\n",
"# @markdown Combining the sigmoid and the cross entropy in a single function in this way improves numerical stability.\n",
"\n",
"\n",
"def logistic_regression_loss(X, y, vmap_model, w, alpha):\n",
" data_fit_loss = sigmoid_binary_cross_entropy(vmap_model(X, w), y).sum(\n",
" axis=0\n",
" ) # optax loss has sigmoid integrated so we just use linear model\n",
" parameter_norm = 0.5 * alpha * (w**2).sum()\n",
" return data_fit_loss + parameter_norm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8VsT7hjFBtA8"
},
"outputs": [],
"source": [
"# @title Plotting the loss landscape and finding the solution at the mode\n",
"# @markdown Notice that the loss is not symmetric around its mode\n",
"\n",
"alpha = 0.75\n",
"\n",
"\n",
"def generate_loss_fun():\n",
" def loss_fun(w):\n",
" return logistic_regression_loss(\n",
" X=affine_basis(x_data_list_classification),\n",
" y=y_data_list_classification,\n",
" vmap_model=vmap_linear,\n",
" w=w,\n",
" alpha=alpha,\n",
" )\n",
"\n",
" return loss_fun\n",
"\n",
"\n",
"loss_fun = generate_loss_fun()\n",
"\n",
"x0_grid, x1_grid, loss_grid = generate_loss_grid(\n",
" loss_fun=loss_fun, grid_size=200, lim0=[-8, 8], lim1=[-6, 10]\n",
")\n",
"\n",
"\n",
"def gen_objective(X, y, vmap_model, alpha):\n",
" def objective(params, key):\n",
" return -logistic_regression_loss(X, y, vmap_model, params, alpha)\n",
"\n",
" return objective\n",
"\n",
"\n",
"objective = gen_objective(\n",
" affine_basis(x_data_list_classification),\n",
" y_data_list_classification,\n",
" vmap_linear,\n",
" alpha,\n",
")\n",
"\n",
"optimised_regularised_classification_params = optimise(\n",
" objective,\n",
" params=jnp.array([0.0, 0.0]),\n",
" plotting_func=None,\n",
" LR=1e-2,\n",
" MAX_STEPS=800,\n",
" LOG_EVERY=100,\n",
")\n",
"\n",
"\n",
"fig, ax = plt.subplots(1, 2, dpi=140, figsize=(10, 4))\n",
"mesh = ax[0].pcolormesh(\n",
" x0_grid,\n",
" x1_grid,\n",
" -loss_grid,\n",
" vmin=-loss_grid.min() * 5,\n",
" vmax=-loss_grid.min(),\n",
" cmap=\"viridis\",\n",
")\n",
"cbar = fig.colorbar(mesh, ax=ax[0])\n",
"cbar.set_label(\"Negative regularised loss\", rotation=270, labelpad=20)\n",
"ax[0].set_ylabel(\"w\")\n",
"ax[0].set_xlabel(\"b\")\n",
"ax[0].set_title(\"Loss / posterior landscape for classification\")\n",
"ax[0].grid(alpha=0.3)\n",
"ax[0].scatter(\n",
" optimised_regularised_classification_params[0],\n",
" optimised_regularised_classification_params[1],\n",
" c=\"cyan\",\n",
" label=\"mode\",\n",
")\n",
"ax[0].legend()\n",
"\n",
"\n",
"xlim = [-4, 4]\n",
"ylim = [-0.1, 1.1]\n",
"\n",
"x_pred = jnp.linspace(xlim[0], xlim[1], 101)\n",
"\n",
"\n",
"mode_predictions = sigmoid(\n",
" vmap_linear(affine_basis(x_pred), optimised_regularised_classification_params)\n",
")\n",
"ax[1].plot(x_pred, mode_predictions, c=\"cyan\", label=\"Best fit class probability\")\n",
"ax[1].grid(alpha=0.3)\n",
"ax[1].legend()\n",
"ax[1].plot(x_data_list_classification, y_data_list_classification, \"ob\")\n",
"ax[1].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OATtmRZEA5Lr"
},
"source": [
"**Group task <font color='green'>`Base`</font>** - Discuss with a neighbour and / or tutor:\n",
"\n",
"* Is the loss quadratic? Is the loss convex?\n",
"\n",
"* What is the cause of the posterior being wider for larger values of $w$?\n",
"\n",
"* When normalised, will the posterior be Gaussian?\n",
"\n",
"* Notice that the Gaussian posterior contours from the previous section are symmetric around their mode. Is this loss landscape symmetric around its mode? What do you think would be the implications of approximating this posterior distribution with a Gaussian?\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "wjCALHKMavr7"
},
"outputs": [],
"source": [
"# @title Takeaways for this section (reveal after discussion above):\n",
"\n",
"print(\n",
"    \"-- Outside of the linear-Gaussian regression case, the loss function is usually non-quadratic. This is the case for logistic regression.\\n\"\n",
")\n",
"\n",
"print(\"-- When the loss is non-quadratic, the posterior is non-Gaussian. \\n\")\n",
"\n",
"print(\n",
" '-- Usually, we do not have closed form \"pen and paper\" expressions for non-Gaussian posterior distributions \\n and we need to resort to approximations.\\n'\n",
")\n",
"\n",
"print(\n",
" \"-- If we approximate a non-Gaussian posterior with a Gaussian, we will be placing excess posterior probability density \\nin some regions of parameter space and not enough density in other regions. In the next sections we will see what effect this has on performance.\"\n",
")"
]
},
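{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see that this loss is not quadratic: a quadratic has the same second derivative everywhere, and this loss does not. The sketch below takes a simplified, hypothetical 1D slice of the regularised cross entropy loss (a single weight, no bias term) on the toy classification inputs used above. It applies `jax.grad` twice to evaluate the curvature at two values of $w$.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.nn import log_sigmoid\n",
"\n",
"# Hypothetical 1D slice: a single weight w and no bias, on the toy classification inputs.\n",
"x = jnp.array([-2.7, -1.9, -1.3, 0.7, 0.8, 1.4, 2.4])\n",
"y = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0])\n",
"alpha = 0.75\n",
"\n",
"\n",
"def loss(w):\n",
"    f = x * w\n",
"    cross_entropy = -(y * log_sigmoid(f) + (1 - y) * log_sigmoid(-f)).sum()\n",
"    return cross_entropy + 0.5 * alpha * w ** 2\n",
"\n",
"\n",
"# Second derivative of the loss; for a quadratic this would be a constant.\n",
"curvature = jax.grad(jax.grad(loss))\n",
"c_small = curvature(0.0)\n",
"c_large = curvature(5.0)\n",
"```\n",
"\n",
"For this slice the curvature is about 6.1 at $w = 0$ but below 1 at $w = 5$. Once the sigmoid saturates, the data barely constrain larger weights, which is why the posterior spreads out in that direction. A Gaussian, whose log density has constant curvature, cannot capture this."
]
},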
{
"cell_type": "markdown",
"metadata": {
"id": "CPIlyvQNEMTb"
},
"source": [
"# Section 3: Black Box Variational Inference <a name=\"S3\"></a>\n",
"\n",
"This section implements Black Box Variational Inference. It is the crux of the prac and you should spend up to 30 minutes on it.\n",
"\n",
"The logistic regression posterior is non-quadratic and thus non-Gaussian. Even worse, we do not have a closed form expression for the posterior distribution. However, after [Section 1](#S1), we are very comfortable working with Gaussians.\n",
"\n",
"We are going to try to find the Gaussian which is closest to the logistic regression posterior and use that one. We are going to call our Gaussian approximate distribution $q$. We will learn its mean and covariance parameters $\\mu_q, \\Sigma_q$: $q(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; \\mu_q, \\Sigma_q)$\n",
"\n",
"But, what do we mean by \"closest\"? We need to choose a notion of distance between distributions. We are going to choose the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence). That is, we want to minimise the KL divergence between our approximation $q(\\color{blue}{\\underline{w}})$ and the true posterior $p(\\color{blue}{\\underline{w}} | X, \\underline{y})$ which we write as $\\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}} | X, \\underline{y}))$. This quantity takes a minimum value of 0 which is only reached when $q(\\color{blue}{\\underline{w}}) = p(\\color{blue}{\\underline{w}} | X, \\underline{y})$. \n",
"\n",
"\\\n",
"<font color='green'>`Fun fact`</font>: Strictly speaking, the KL divergence is not a distance because it is not symmetric: in general $KL(q \\,||\\, p) \\neq KL(p \\,||\\, q)$. As the name implies, it is a **divergence**, a type of function that does not require symmetry. \n",
"\n"
]
},
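{
"cell_type": "markdown",
"metadata": {},
"source": [
"The asymmetry is easy to check with the closed-form KL divergence between two univariate Gaussians, $\\text{KL}(\\mathcal{N}(\\mu_q, s_q^2) \\,||\\, \\mathcal{N}(\\mu_p, s_p^2)) = \\log \\frac{s_p}{s_q} + \\frac{s_q^2 + (\\mu_q - \\mu_p)^2}{2 s_p^2} - \\frac{1}{2}$. The helper `kl_gauss_1d` and the parameter values below are illustrative choices, not part of the notebook.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def kl_gauss_1d(mu_q, std_q, mu_p, std_p):\n",
"    # KL(N(mu_q, std_q^2) || N(mu_p, std_p^2)) for univariate Gaussians.\n",
"    return (\n",
"        jnp.log(std_p / std_q)\n",
"        + (std_q ** 2 + (mu_q - mu_p) ** 2) / (2 * std_p ** 2)\n",
"        - 0.5\n",
"    )\n",
"\n",
"\n",
"kl_qp = kl_gauss_1d(0.0, 0.5, 1.0, 2.0)  # KL(q || p)\n",
"kl_pq = kl_gauss_1d(1.0, 2.0, 0.0, 0.5)  # KL(p || q): arguments swapped\n",
"kl_self = kl_gauss_1d(1.0, 2.0, 1.0, 2.0)  # KL(p || p)\n",
"```\n",
"\n",
"With these values, $\\text{KL}(q\\,||\\,p) \\approx 1.04$ while $\\text{KL}(p\\,||\\,q) \\approx 8.11$, and the divergence of a distribution from itself is 0."
]
},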
{
"cell_type": "markdown",
"metadata": {
"id": "0H0gTiTlGoB8"
},
"source": [
"## Introducing the ELBO\n",
"\n",
"We don't know the analytical form of $p(\\color{blue}{\\underline{w}} | X, \\underline{y})$, so how do we minimise $\\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}} | X, \\underline{y}))$?\n",
"\n",
"We are going to use a trick called the **Evidence Lower BOund (ELBO)**. Specifically, we will show that maximising the objective \n",
"\n",
"<center>\n",
"$\\text{ELBO} = \\mathbb{E}_{q(\\color{blue}{\\underline{w}})}[\\log p(\\underline{y} | X, \\color{blue}{\\underline{w}})] - \\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}}))$\n",
"</center>\n",
"\n",
"is equivalent to minimising $\\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}} | X, \\underline{y}))$. This is great because the ELBO does not contain the posterior anymore, only the likelihood $p(\\underline{y} | X, \\color{blue}{\\underline{w}})$ and prior $p(\\color{blue}{\\underline{w}})$!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mGGDtq7wISOI"
},
"source": [
"**Math Task** (Optional) <font color='red'>`Advanced`</font>: Derive the ELBO\n",
"\n",
"\n",
"<center>\n",
"Show that minimising $\\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}} | X, \\underline{y}))$ is equivalent to maximising $\\mathbb{E}_{q(\\color{blue}{\\underline{w}})}[\\log p(\\underline{y} | X, \\color{blue}{\\underline{w}})] - \\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}}))$\n",
"</center>\n",
"\n",
"\\\n",
"* Hint 1: The KL divergence can be written as $\\text{KL}\\left(q(a)\\,||\\,p(a)\\right) = \\mathbb{E}_{q(a)}[\\log q(a) - \\log p(a)]$. \n",
"\n",
"Use this together with decomposing the posterior using Bayes' rule: $p(\\color{blue}{\\underline{w}} | X, \\underline{y}) = \\frac{p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}})}{p(\\underline{y} | X)}$\n",
"\n",
"\\\n",
"* Hint 2: The ELBO gets its name because it is a lower bound on the (log) evidence $\\log p(\\underline{y} | X)$. Try separating out this term, noticing that it does not depend on $\\color{blue}{\\underline{w}}$, so that $\\mathbb{E}_{q(\\color{blue}{\\underline{w}})}[\\log p(\\underline{y} | X)] = \\log p(\\underline{y} | X)$.\n",
"\n",
"To set up the inequality $\\log p(\\underline{y} | X) \\geq ...$ you can use the fact that the KL divergence is never negative, or alternatively [Jensen's](https://en.wikipedia.org/wiki/Jensen%27s_inequality) inequality: $\\quad \\log \\mathbb{E}[p] \\geq \\mathbb{E}[\\log p]$.\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "HLzvY-yWeGYq"
},
"outputs": [],
"source": [
"# @title Answer to math task (Try not to run until you've given it a good try!)\n",
"from IPython.display import display, Markdown, Latex\n",
"\n",
"display(\"Solution Derivation\")\n",
"display(\n",
" Latex(\n",
" \"$\\\\text{KL}(q(\\\\underline{w}) \\, ||\\, p(\\\\underline{w} | X, \\\\underline{y})) = \\\\mathbb{E}_{q(\\\\underline{w})}[\\\\log q(\\\\underline{w}) - \\\\log p(\\\\underline{w} | X, \\\\underline{y})] = \\mathbb{E}_{q(\\\\underline{w})}[\\\\log q(\\\\underline{w}) - \\\\log \\\\frac{p(\\\\underline{y} | X, \\\\underline{w}) p(\\\\underline{w})}{p(\\\\underline{y} | X)}]$\"\n",
" )\n",
")\n",
"\n",
"\n",
"display(\n",
" Latex(\n",
" \"$ = \\\\mathbb{E}_{q(\\\\underline{w})}[\\\\log q(\\\\underline{w}) - \\\\log p(\\\\underline{y} | X, \\\\underline{w}) - \\\\log p(\\\\underline{w}) + \\\\log p(\\\\underline{y} | X)]$\"\n",
" )\n",
")\n",
"display(\n",
" Latex(\n",
" \"$ = \\\\mathbb{E}_{q(\\\\underline{w})}[- \\\\log p(\\\\underline{y} | X, \\\\underline{w})] + \\\\mathbb{E}_{q(\\\\underline{w})}[ \\\\log q(\\\\underline{w}) - \\\\log p(\\\\underline{w})] + \\\\log p(\\\\underline{y} | X)$\"\n",
" )\n",
")\n",
"display(\n",
" Latex(\n",
"        \"$ = - \\\\mathbb{E}_{q(\\\\underline{w})}[\\\\log p(\\\\underline{y} | X, \\\\underline{w})] + \\\\text{KL}\\\\left(q(\\\\underline{w}) \\,||\\, p(\\\\underline{w})\\\\right) + \\\\log p(\\\\underline{y} | X)$\"\n",
" )\n",
")\n",
"\n",
"\n",
"display(\"Thus\")\n",
"display(\n",
" Latex(\n",
"        \"$\\\\log p(\\\\underline{y} | X) - \\\\text{KL}\\\\left(q(\\\\underline{w}) \\, ||\\, p(\\\\underline{w} | X, \\\\underline{y})\\\\right) = \\\\mathbb{E}_{q(\\\\underline{w})}[\\\\log p(\\\\underline{y} | X, \\\\underline{w})] - \\\\text{KL}\\\\left(q(\\\\underline{w}) \\,||\\, p(\\\\underline{w})\\\\right) = \\\\text{ELBO}$\"\n",
" )\n",
")\n",
"\n",
"display(\"Since\")\n",
"display(Latex(\"$\\\\log p(\\\\underline{y} | X)$\"))\n",
"display(\n",
"    \"is constant with respect to q, maximising the ELBO is equivalent to minimising\"\n",
")\n",
"display(\n",
" Latex(\n",
" \"$\\\\text{KL}\\\\left(q(\\\\underline{w}) \\, ||\\, p(\\\\underline{w} | X, \\\\underline{y})\\\\right)$\"\n",
" )\n",
")"
]
},
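{
"cell_type": "markdown",
"metadata": {},
"source": [
"The identity derived above, $\\log p(\\underline{y} | X) - \\text{KL}(q \\,||\\, p(\\color{blue}{\\underline{w}} | X, \\underline{y})) = \\text{ELBO}$, can be checked numerically on a toy model where every term has a closed form. The sketch assumes a hypothetical conjugate model $y_n \\sim \\mathcal{N}(w, 1)$ with prior $w \\sim \\mathcal{N}(0, \\alpha^{-1})$ and made-up data; it is not the notebook's regression setup. The helper `kl_gauss_var` is also illustrative.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax.scipy.stats import multivariate_normal\n",
"\n",
"\n",
"def kl_gauss_var(mu_q, var_q, mu_p, var_p):\n",
"    # KL(N(mu_q, var_q) || N(mu_p, var_p)) for univariate Gaussians.\n",
"    return 0.5 * (jnp.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)\n",
"\n",
"\n",
"# Hypothetical conjugate model: y_n ~ N(w, 1), prior w ~ N(0, 1 / alpha).\n",
"y = jnp.array([0.3, 1.2, 0.8, 1.9])\n",
"alpha = 2.0\n",
"N = y.shape[0]\n",
"\n",
"# Exact posterior N(post_mean, post_var) and exact log evidence log p(y).\n",
"post_var = 1.0 / (alpha + N)\n",
"post_mean = post_var * y.sum()\n",
"evidence_cov = jnp.eye(N) + jnp.ones((N, N)) / alpha\n",
"log_evidence = multivariate_normal.logpdf(y, jnp.zeros(N), evidence_cov)\n",
"\n",
"\n",
"def elbo(m, v):\n",
"    # For q = N(m, v): E_q[(y_n - w)^2] = (y_n - m)^2 + v, so E_q[log p(y | w)] is closed form.\n",
"    expected_ll = jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - 0.5 * ((y - m) ** 2 + v))\n",
"    return expected_ll - kl_gauss_var(m, v, 0.0, 1.0 / alpha)\n",
"\n",
"\n",
"# An arbitrary q: the identity should hold for any (m, v).\n",
"m, v = 0.1, 0.5\n",
"lhs = log_evidence - kl_gauss_var(m, v, post_mean, post_var)\n",
"rhs = elbo(m, v)\n",
"```\n",
"\n",
"Setting `m, v` to the exact posterior makes the KL term vanish, so the ELBO then equals the log evidence. For any other choice it is strictly smaller."
]
},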
{
"cell_type": "markdown",
"metadata": {
"id": "fp-5CgX1Ip7p"
},
"source": [
"**Group task <font color='green'>`Base`</font>** - Discuss with a neighbour and / or tutor:\n",
"\n",
"* When will the KL to the true posterior $\\text{KL}(q(\\color{blue}{\\underline{w}}) \\, ||\\, p(\\color{blue}{\\underline{w}} | X, \\underline{y}))$ be 0, if ever?\n",
"\n",
"* What is the maximum value that the ELBO can take?\n",
"\n",
"* Can one overfit by optimising the ELBO?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "syl0y6UjLZmr"
},
"source": [
"## Implementing the ELBO\n",
"\n",
"Now we are going to implement an ELBO that we can apply to any model to perform (approximate) Bayesian inference.\n",
"\n",
"Most of the code is written for you and you will just have to complete some bits."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cED67Bzn3We5"
},
"source": [
"### We are going to learn a Gaussian $q(\\color{blue}{\\underline{w}}) = \\mathcal{N}(\\color{blue}{\\underline{w}}; \\mu_q, \\Sigma_{q})$\n",
"\n",
"In practice, this means that we need to learn the mean vector $\\mu_q$ and the covariance matrix $\\Sigma_q$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wubw_EfhMuSQ"
},
"outputs": [],
"source": [
"# @title Code to initialise parameters and construct the covariance matrix\n",
"\n",
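"# @markdown A (non-degenerate) covariance matrix needs to be [positive definite](https://en.wikipedia.org/wiki/Definite_matrix).\n",
"\n",
"# @markdown To ensure this always holds, we parametrise the covariance as $\\Sigma_{q} = (L_o + \\text{diag}(\\sigma)) (L_o + \\text{diag}(\\sigma))^T$, where $L_o$ is a strictly lower triangular matrix and $\\sigma = \\exp(\\text{sig})$ is a vector of strictly positive diagonal entries. A lower triangular matrix with a positive diagonal is invertible, so its product with its own transpose is positive definite.\n",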
"\n",
"# @markdown **code review task** <font color='green'>`Base`</font>: read through this code to convince yourself it computes the covariance as described\n",
"\n",
"\n",
"def initialize_params(D, key=random.PRNGKey(0)):\n",
" \"\"\"\n",
" Return randomly initialised parameters of the variational posterior q\n",
" \"\"\"\n",
" dist_mean = normal(stddev=1)\n",
" dist_cov = normal(stddev=1e-3)\n",
" keys = random.split(key, 5)\n",
"\n",
" params = {}\n",
" params[\"w\"] = dist_mean(keys[0], (D,))\n",
" params[\"L_o\"] = jnp.tril(dist_cov(keys[1], (D, D)), -1)\n",
" params[\"sig\"] = dist_cov(keys[2], (D,))\n",
" return params\n",
"\n",
"\n",
"@jit\n",
"def get_L(params, min_diag=1e-6):\n",
" \"\"\"\n",
"    Construct the lower triangular square root of the covariance matrix as L_o + diag(exp(sig))\n",
" \"\"\"\n",
" log_diag = params[\"sig\"]\n",
" off_diag = params[\"L_o\"]\n",
" sig_diag = jnp.diag(jnp.clip(jnp.exp(log_diag), a_min=min_diag))\n",
" tril_L = jnp.tril(off_diag, -1)\n",
" L = tril_L + sig_diag\n",
" return L\n",
"\n",
"\n",
"@jit\n",
"def get_Sig(params, min_diag=1e-6):\n",
" \"\"\"\n",
"    Construct the covariance matrix as L L^T, with L = L_o + diag(exp(sig))\n",
" \"\"\"\n",
" L = get_L(params, min_diag=min_diag)\n",
" return jnp.matmul(L, L.T)"
]
},
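{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a quick numerical check of why this parametrisation works. It uses hypothetical fixed values for the unconstrained parameters: `L_o` and `log_sig` play the roles of the `L_o` and `sig` entries of `params`, but the snippet is standalone and omits the `min_diag` clipping done in `get_L`. Any real values give a lower triangular `L` with a strictly positive diagonal. As a result, `Sigma = L @ L.T` has strictly positive eigenvalues, and `L` is exactly its Cholesky factor.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical unconstrained parameters: any real values are allowed here.\n",
"L_o = jnp.array([[0.0, 0.0, 0.0],\n",
"                 [0.7, 0.0, 0.0],\n",
"                 [-1.2, 0.4, 0.0]])  # strictly lower triangular\n",
"log_sig = jnp.array([0.0, -0.5, 0.3])  # log of the diagonal entries\n",
"\n",
"L = L_o + jnp.diag(jnp.exp(log_sig))  # lower triangular, positive diagonal\n",
"Sigma = L @ L.T\n",
"\n",
"eigenvalues = jnp.linalg.eigvalsh(Sigma)\n",
"chol = jnp.linalg.cholesky(Sigma)  # the unique Cholesky factor of Sigma\n",
"```"
]
},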
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b54CKBYk4U51"
},
"outputs": [],
"source": [
"# @markdown We are going to compute the expectation $\\mathbb{E}_{q(\\color{blue}{\\underline{w}})}[\\log p(\\underline{y} | X, \\color{blue}{\\underline{w}})]$ with a Monte Carlo estimate. For that we need to draw samples from $q$.\n",
"\n",
"# @markdown To sample from $\\mathcal{N}(\\mu_q, \\Sigma_{q})$, we use a trick (the reparameterisation trick) to turn samples from $\\mathcal{N}(0, I)$ into samples from $q$.\n",
"\n",
"# @markdown Specifically, we compute $\\mu_q + (L_o + \\text{diag}(\\sigma)) \\epsilon\\,\\,$ with $\\epsilon \\sim \\mathcal{N}(0, I)$\n",
"\n",
"# @markdown **code review task** <font color='green'>`Base`</font>: read through this code to verify it draws samples as described\n",
"\n",
"\n",
"def sample_weights(params, key, Nsamples=100):\n",
" \"\"\"\n",
"    Return a matrix of size (K, D) containing K = Nsamples samples from\n",
"    our variational distribution q, one sample per row\n",
" \"\"\"\n",
" mu = params[\"w\"]\n",
" L = get_L(params, min_diag=1e-6)\n",
" eps = random.normal(key, shape=(mu.shape[0], Nsamples))\n",
" w = mu + jnp.matmul(L, eps).T\n",
" return w"
]
},
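{
"cell_type": "markdown",
"metadata": {},
"source": [
"This sampling step is the reparameterisation trick. Each sample is a deterministic, differentiable function of $\\mu_q$ and $L$, so gradients of a Monte Carlo estimate can flow back to the variational parameters. The standalone sketch below uses hypothetical values for `mu` and `L` and the same layout as `sample_weights` (one sample per row). It checks that samples built this way have mean $\\mu_q$ and covariance $L L^T$.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical variational parameters for a 2D q.\n",
"mu = jnp.array([1.0, -2.0])\n",
"L = jnp.array([[0.8, 0.0], [0.5, 0.3]])  # lower triangular, positive diagonal\n",
"n_samples = 200_000\n",
"\n",
"eps = jax.random.normal(jax.random.PRNGKey(1), (2, n_samples))  # eps ~ N(0, I)\n",
"w = mu + (L @ eps).T  # shape (n_samples, 2): one sample of N(mu, L L^T) per row\n",
"\n",
"empirical_mean = w.mean(axis=0)\n",
"empirical_cov = jnp.cov(w, rowvar=False)\n",
"target_cov = L @ L.T\n",
"```"
]
},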
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Jd_5rCowLRPQ"
},
"outputs": [],
"source": [
"# @title Implementation of KL divergence between Gaussians\n",
"\n",
"# @markdown this code computes $\\text{KL}(\\mathcal{N}(\\mu_q, \\Sigma_q) \\,||\\, \\mathcal{N}(0, \\alpha^{-1}I))$\n",
"\n",
"# @markdown In the interest of time, **you do not need to check this one** since we have not covered how to derive the expression.\n",
"\n",
"# @markdown If you are curious, [Here](https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/) is a nice resource showing the derivation\n",
"\n",
"\n",
"@jit\n",
"def rho_logdet(rho):\n",
" return 2 * jnp.sum(rho).clip(\n",
" a_min=-700\n",
" ) # empirically -730 is close to torch log(0)\n",
"\n",
"\n",
"@jit\n",
"def KLD_cost(params, prior_log_std):\n",
" \"\"\"\n",
" KL divergence between a full covariance Gaussian and an isotropic Gaussian\n",
"\n",
" Args:\n",
" prior_log_std: scalar or vector of size D\n",
" \"\"\"\n",
"\n",
" q_mu = params[\"w\"]\n",
" q_logdet = rho_logdet(params[\"sig\"])\n",
" q_Sig = get_Sig(params)\n",
"\n",
" d = q_Sig.shape[0]\n",
"\n",
" var_p = jnp.exp(2 * prior_log_std)\n",
"\n",
" Sig_p_inv = jnp.eye(d) * 1 / var_p\n",
" p_logdet = 2 * (prior_log_std * d).clip(a_min=-700)\n",
"\n",
" mu_sigma_inv_mu = q_mu @ Sig_p_inv @ q_mu\n",
" logdet_ratio = p_logdet - q_logdet\n",
" trace_term = jnp.trace(jnp.matmul(Sig_p_inv, q_Sig))\n",
"\n",
" KLD = 0.5 * (trace_term + mu_sigma_inv_mu - d + logdet_ratio)\n",
"\n",
" return KLD"
]
},
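{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to sanity-check the closed form without deriving it, a Monte Carlo estimate of $\\mathbb{E}_{q}[\\log q(\\color{blue}{\\underline{w}}) - \\log p(\\color{blue}{\\underline{w}})]$ should agree with it up to sampling noise. The sketch below is standalone: `kl_closed_form`, `mu_q`, `L` and `alpha` are hypothetical names and values. It computes the log determinant with `jnp.linalg.slogdet` rather than the `rho_logdet` shortcut used above.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.scipy.stats import multivariate_normal\n",
"\n",
"# Hypothetical q = N(mu_q, L L^T) and isotropic prior p = N(0, alpha^-1 I).\n",
"mu_q = jnp.array([0.5, -1.0])\n",
"L = jnp.array([[0.6, 0.0], [0.2, 0.4]])\n",
"Sigma_q = L @ L.T\n",
"alpha = 2.0\n",
"d = mu_q.shape[0]\n",
"\n",
"\n",
"def kl_closed_form(mu_q, Sigma_q, alpha):\n",
"    # 0.5 * [tr(Sigma_p^-1 Sigma_q) + mu^T Sigma_p^-1 mu - d + logdet Sigma_p - logdet Sigma_q],\n",
"    # where Sigma_p^-1 = alpha * I and logdet Sigma_p = -d * log(alpha).\n",
"    trace_term = alpha * jnp.trace(Sigma_q)\n",
"    quad_term = alpha * mu_q @ mu_q\n",
"    logdet_p = -d * jnp.log(alpha)\n",
"    _, logdet_q = jnp.linalg.slogdet(Sigma_q)\n",
"    return 0.5 * (trace_term + quad_term - d + logdet_p - logdet_q)\n",
"\n",
"\n",
"# Monte Carlo estimate of E_q[log q(w) - log p(w)] using reparameterised samples.\n",
"eps = jax.random.normal(jax.random.PRNGKey(2), (100_000, d))\n",
"w = mu_q + eps @ L.T\n",
"log_q = multivariate_normal.logpdf(w, mu_q, Sigma_q)\n",
"log_p = multivariate_normal.logpdf(w, jnp.zeros(d), jnp.eye(d) / alpha)\n",
"kl_mc = jnp.mean(log_q - log_p)\n",
"kl_exact = kl_closed_form(mu_q, Sigma_q, alpha)\n",
"```"
]
},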
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pq4fwvU2LRWN"
},
"outputs": [],
"source": [
"# @title Implementation of log-likelihood functions\n",
"# @markdown These functions compute $\\log p(\\underline{y} | X, \\color{blue}{\\underline{w}})$ for the Gaussian and Bernoulli distributions respectively.\n",
"\n",
"# @markdown Remember the equivalence between log-likelihoods and standard loss functions described in the previous sections.\n",
"\n",
"# @markdown For regression, we are relying on the scipy-style `norm.logpdf` implementation. However, we might as well have written $-0.5||\\underline y - X\\color{blue}{\\underline{w}}||^2_2$, which differs from it only by an additive constant.\n",
"\n",
"# @markdown For classification, we are relying on the negated sigmoid binary cross entropy loss.\n",
"\n",
"\n",
"@jit\n",
"def gaussian_ll(pred, y):\n",
" z = y - pred\n",
" return norm.logpdf(z, loc=0, scale=1)\n",
"\n",
"\n",
"def bernouilli_ll(pred, y):\n",
" return -sigmoid_binary_cross_entropy(pred, y)"
]
},
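{
"cell_type": "markdown",
"metadata": {},
"source": [
"The claim that the Gaussian log-likelihood is the squared-error loss up to an additive constant can be checked directly. In this sketch the residuals `z` are made-up values, and `norm` is imported from `jax.scipy.stats` (this may differ from the notebook's own import). The offset between `norm.logpdf` and $-\\frac{1}{2} z^2$ should be the same constant, $-\\frac{1}{2}\\log 2\\pi$, for every residual.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from jax.scipy.stats import norm\n",
"\n",
"z = jnp.linspace(-3.0, 3.0, 7)  # hypothetical residuals y - pred\n",
"ll_logpdf = norm.logpdf(z, loc=0, scale=1)\n",
"ll_squared_error = -0.5 * z ** 2\n",
"# The difference should be the same constant, -0.5 * log(2 * pi), for every residual.\n",
"offset = ll_logpdf - ll_squared_error\n",
"```"
]
},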
{
"cell_type": "markdown",
"metadata": {
"id": "DodADk9AmTqN"
},
"source": [
"**Coding task** <font color='green'>`Base`</font>: putting it all together -- Implement the ELBO\n",
"\n",
"We have code to compute the KL term in the ELBO, we have code to sample from q and we have code to compute log-likelihoods $\\log p(\\underline{y} | X, \\color{blue}{\\underline{w}_s})$\n",
"\n",
"You will implement $\\mathbb{E}_{q(\\color{blue}{\\underline{w}})}[\\log p(\\underline{y} | X, \\color{blue}{\\underline{w}})]$ using a Monte Carlo estimator $\\frac{1}{N_{\\text{samples}}}\\sum_{s=1}^{N_{\\text{samples}}} \\log p(\\underline{y} | X, \\color{blue}{\\underline{w}_s})$ with $\\color{blue}{\\underline{w}_s} \\sim q(\\color{blue}{\\underline{w}})$\n",
"\n",
"* hint1: use the `sample_weights` method that was implemented in the cells above to draw samples (note that it returns one sample per row). You can then push the samples through the `model` function to get predictions\n",
"\n",
"* hint2: `ll_func`, which is passed in as an argument, takes `pred` and `y` as inputs and returns the per-observation log-likelihoods $\\log p(y_n | x_n, \\color{blue}{\\underline{w}_s})$\n",
"\n",
"After this, the rest of the prac will just consist of running this code and studying how it behaves.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7GFRVB5fOtAQ"
},
"outputs": [],
"source": [
"def gen_Gaussian_ELBO(model, alpha, ll_func, Nsamples=100):\n",
" \"\"\"\n",
"    Generates a function that returns a Monte Carlo estimate of the ELBO for a Gaussian variational posterior q\n",
"\n",
"    Args:\n",
"        model: vmap model that takes X and w as inputs\n",
"        alpha: scalar - regularisation strength or prior precision\n",
" ll_func: function that maps from predictions and targets to log densities\n",
" Nsamples: integer, how many samples to draw for MonteCarlo estimator\n",
" Returns:\n",
" \n",
" Gaussian_ELBO: function which we can evaluate at our data points\n",
" and returns the ELBO estimate\n",
" \"\"\"\n",
" prior_log_std = - 0.5 * jnp.log(alpha)\n",
"\n",
" def ll_term(params, key, X, y):\n",
" w = # Your code goes here. \n",
" predictions = # Your code goes here. \n",
" ll = # Your code goes here. \n",
"        # ll should be a matrix of size (Nobservations, Nsamples)\n",
" return ll.sum(axis=0).mean(axis=0) # sum over observations and average over samples\n",
"\n",
" def Gaussian_ELBO(params, key, X, y):\n",
" KL = KLD_cost(params, prior_log_std)\n",
" ll = ll_term(params, key, X, y)\n",
" return ll - KL\n",
" \n",
" return jit(Gaussian_ELBO)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Pe0uZVfNPGbF"
},
"outputs": [],
"source": [
"# @title Run me to test your code\n",
"\n",
"alpha = 1\n",
"\n",
"Gaussian_ELBO = gen_Gaussian_ELBO(\n",
" model=vmap_linear, alpha=alpha, ll_func=gaussian_ll, Nsamples=1000\n",
")\n",
"key = random.PRNGKey(0)\n",
"params = initialize_params(D=2, key=key)\n",
"objective = jit(\n",
" partial(\n",
" Gaussian_ELBO,\n",
" X=affine_basis(x_data_list_classification),\n",
" y=y_data_list_classification,\n",
" )\n",
")\n",
"\n",
"ELBO_correct = -28.404984\n",
"\n",
"assert (\n",
" jnp.abs(objective(params, key) - ELBO_correct) < 1e-3\n",
"), \"ELBO is not calculated correctly\"\n",
"\n",
"print(\"It seems correct. Look at the answer below to compare methods.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "mqaU3YzwPGdo"
},
"outputs": [],
"source": [
"# @title Answer (Try not to peek until you've given it a good try!)\n",
"\n",
"\n",
"def gen_Gaussian_ELBO(model, alpha, ll_func, Nsamples=100):\n",
" \"\"\"\n",
"    Generates a function that returns a Monte Carlo estimate of the ELBO for a Gaussian variational posterior q\n",
"\n",
"    Args:\n",
"        model: vmap model that takes X and w as inputs\n",
"        alpha: scalar - regularisation strength or prior precision\n",
" ll_func: function that maps from predictions and targets to log densities\n",
" Nsamples: integer, how many samples to draw for MonteCarlo estimator\n",
" Returns:\n",
"\n",
" Gaussian_ELBO: function which we can evaluate at our data points and returns the ELBO estimate\n",
" \"\"\"\n",
" prior_log_std = -0.5 * jnp.log(alpha)\n",
"\n",
" def ll_term(params, key, X, y):\n",
" w = sample_weights(params, key, Nsamples)\n",
" predictions = model(\n",
" X,\n",
" w.T,\n",
" )\n",
" ll = ll_func(predictions, y[:, None])\n",
" return ll.sum(axis=0).mean(\n",
" axis=0\n",
" ) # sum over observations and average over samples\n",
"\n",
" def Gaussian_ELBO(params, key, X, y):\n",
" KL = KLD_cost(params, prior_log_std)\n",
" ll = ll_term(params, key, X, y)\n",
" return ll - KL\n",
"\n",
" return jit(Gaussian_ELBO)"
]
},
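{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for how accurate the Monte Carlo estimate of the expected log-likelihood is, the sketch below uses a model where the expectation also has a closed form. It is a hypothetical one-weight model ($y_n \\sim \\mathcal{N}(w, 1)$, $q(w) = \\mathcal{N}(m, s^2)$) with made-up data. It follows the same shape convention as `ll_term`: a `(Nobservations, Nsamples)` matrix that is summed over observations and averaged over samples.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.scipy.stats import norm\n",
"\n",
"# Hypothetical setting: y_n ~ N(w, 1) with a single weight and q(w) = N(m, s^2).\n",
"y = jnp.array([0.3, 1.2, 0.8, 1.9])\n",
"m, s = 0.5, 0.7\n",
"n_samples = 100_000\n",
"\n",
"w_samples = m + s * jax.random.normal(jax.random.PRNGKey(3), (n_samples,))\n",
"# (Nobservations, Nsamples) matrix of per-point log-likelihoods via broadcasting.\n",
"ll = norm.logpdf(y[:, None], loc=w_samples[None, :], scale=1.0)\n",
"mc_estimate = ll.sum(axis=0).mean(axis=0)  # sum over observations, average over samples\n",
"\n",
"# Closed form for comparison, using E_q[(y_n - w)^2] = (y_n - m)^2 + s^2.\n",
"exact = jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - 0.5 * ((y - m) ** 2 + s ** 2))\n",
"```\n",
"\n",
"The variance of the estimate shrinks as `1 / n_samples`, so fewer samples make each ELBO evaluation cheaper but noisier."
]
},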
{
"cell_type": "markdown",
"metadata": {
"id": "ahCZzOcZECBh"
},
"source": [
"## Let's recover the linear regression exact posterior with variational inference\n",
"\n",
"First, let's run a sanity check: we are going to run our black box variational inference algorithm on our linear regression problem. Since the true posterior is Gaussian, our Gaussian approximation $q$ should be able to recover it exactly, up to Monte Carlo noise and optimisation error.\n",
"\n",
"This will allow us to verify that our implementation does not have bugs!"
]
},
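{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same sanity check can be reproduced in miniature. The sketch assumes a hypothetical one-weight conjugate model ($y_n \\sim \\mathcal{N}(w, 1)$, $w \\sim \\mathcal{N}(0, \\alpha^{-1})$, made-up data). In this model the ELBO of a Gaussian $q = \\mathcal{N}(m, s^2)$ has a closed form, so plain gradient ascent with `jax.grad` stands in for the notebook's Monte Carlo estimate and `optimise` helper. If the algorithm is correct, $q$ should converge to the exact posterior $\\mathcal{N}\\left(\\frac{\\sum_n y_n}{\\alpha + N}, \\frac{1}{\\alpha + N}\\right)$.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical conjugate model: y_n ~ N(w, 1), prior w ~ N(0, 1 / alpha).\n",
"y = jnp.array([0.3, 1.2, 0.8, 1.9])\n",
"alpha = 2.0\n",
"\n",
"\n",
"def elbo(params):\n",
"    m, log_s = params\n",
"    v = jnp.exp(2 * log_s)  # q = N(m, v); log std parametrisation keeps v > 0\n",
"    expected_ll = jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - 0.5 * ((y - m) ** 2 + v))\n",
"    prior_var = 1.0 / alpha\n",
"    kl = 0.5 * (jnp.log(prior_var / v) + (v + m ** 2) / prior_var - 1.0)\n",
"    return expected_ll - kl\n",
"\n",
"\n",
"grad_elbo = jax.jit(jax.grad(elbo))\n",
"params = jnp.array([0.0, 0.0])\n",
"for _ in range(2000):\n",
"    params = params + 0.05 * grad_elbo(params)  # plain gradient ascent\n",
"\n",
"exact_var = 1.0 / (alpha + y.shape[0])\n",
"exact_mean = exact_var * y.sum()\n",
"```"
]
},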
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "OPIdFgAIAI8H"
},
"outputs": [],
"source": [
"# @title Re-compute the linear regression true posterior and loss surface as a reference -- run this cell but there is no need to look at the code\n",
"\n",
"alpha = 3\n",
"\n",
"mu, covariance = BLR_posterior(\n",
" affine_basis(x_data_list_outlier), y_data_list_outlier, alpha\n",
")\n",
"\n",
"\n",
"def generate_loss_fun():\n",
" def loss_fun(w):\n",
" return regularised_linear_regression_loss(\n",
" X=affine_basis(x_data_list_outlier),\n",
" y=y_data_list_outlier,\n",
" w=w,\n",
" vmap_model=vmap_linear,\n",
" alpha=alpha,\n",
" )\n",
"\n",
" return loss_fun\n",
"\n",
"\n",
"loss_fun = generate_loss_fun()\n",
"\n",
"x0_grid, x1_grid, loss_grid = generate_loss_grid(\n",
" loss_fun=loss_fun, grid_size=200, lim0=[-8, 8], lim1=[-8, 8]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "pvnQ01eAwjqd"
},
"outputs": [],
"source": [
"# @title Plotting function for optimisation -- run this cell but there is no need to look at the code\n",
"\n",
"\n",
"def gen_plots(params):\n",
"\n",
" q_covariance = get_Sig(params, min_diag=1e-6)\n",
" q_mu = params[\"w\"]\n",
"\n",
" fig, ax = plt.subplots(1, 2, dpi=180, figsize=(10, 4))\n",
"\n",
" loss_landscape = ax[0].pcolormesh(\n",
" x0_grid,\n",
" x1_grid,\n",
" -loss_grid,\n",
" vmin=-loss_grid.min() * 5,\n",
" vmax=-loss_grid.min(),\n",
" cmap=\"viridis\",\n",
"    )\n",
" plot_log_gaussian_ellipse(\n",
" ax=ax[0],\n",
" mean=mu,\n",
" cov=covariance,\n",
" color=\"r\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"exact posterior\",\n",
" MAP_size=25,\n",
" std_levels=[1, 2, 4, 6],\n",
" )\n",
" plot_log_gaussian_ellipse(\n",
" ax=ax[0],\n",
" mean=q_mu,\n",
" cov=q_covariance,\n",
" color=\"cyan\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"variational posterior\",\n",
" MAP_size=25,\n",
" std_levels=[1, 2, 4, 6],\n",
" )\n",
"\n",
" cbar = fig.colorbar(loss_landscape, ax=ax[0])\n",
"    cbar.set_label(\"Negative regularised loss\", rotation=270, labelpad=20)\n",
" ax[0].set_ylabel(\"w\")\n",
" ax[0].set_xlabel(\"b\")\n",
" ax[0].grid(alpha=0.3)\n",
" ax[0].legend()\n",
" ax[0].set_ylim([-8, 8])\n",
" ax[0].set_xlim([-8, 8])\n",
" ax[0].set_title(\"comparing posteriors in weight space\")\n",
"\n",
" # Plot posterior predictive\n",
"\n",
" x_pred = jnp.linspace(-4, 4, 201)\n",
" X_pred = affine_basis(x_pred)\n",
"\n",
" xlim = [jnp.min(x_data_list_outlier) - 1, jnp.max(x_data_list_outlier) + 1]\n",
" ylim = [jnp.min(y_data_list_outlier) - 1, jnp.max(y_data_list_outlier) + 1]\n",
"\n",
" predictive_mean, predictive_std = BLR_predictions(X_pred, mu, covariance)\n",
" q_predictive_mean, q_predictive_std = BLR_predictions(X_pred, q_mu, q_covariance)\n",
"\n",
" errorfill(\n",
" x_pred,\n",
" predictive_mean,\n",
" predictive_std,\n",
" color=\"red\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax[1],\n",
" lw=2,\n",
" linestyle=\"-\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"exact predictive posterior\",\n",
" markevery=None,\n",
" )\n",
"\n",
" errorfill(\n",
" x_pred,\n",
" q_predictive_mean,\n",
" q_predictive_std,\n",
" color=\"cyan\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax[1],\n",
" lw=2,\n",
" linestyle=\"-\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"variational predictive posterior\",\n",
" markevery=None,\n",
" )\n",
"\n",
" ax[1].plot(x_data_list_outlier, y_data_list_outlier, \"ob\")\n",
" ax[1].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=ylim)\n",
" ax[1].grid(alpha=0.3)\n",
" ax[1].legend()\n",
" ax[1].set_title(\"comparing posteriors in funciton space\")\n",
"\n",
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "GcdI9d22ECY6"
},
"outputs": [],
"source": [
"# @title Run me to train variational distribution on linear regression task\n",
"\n",
"\n",
"Gaussian_ELBO = gen_Gaussian_ELBO(\n",
" model=vmap_linear, alpha=alpha, ll_func=gaussian_ll, Nsamples=1000\n",
")\n",
"\n",
"\n",
"params = initialize_params(D=2, key=random.PRNGKey(0))\n",
"\n",
"objective = jit(\n",
" partial(Gaussian_ELBO, X=affine_basis(x_data_list_outlier), y=y_data_list_outlier)\n",
")\n",
"\n",
"\n",
"optimised_regularised_params = optimise(\n",
" objective,\n",
" params=params,\n",
" plotting_func=gen_plots,\n",
" LR=5e-3,\n",
" MAX_STEPS=1500,\n",
" LOG_EVERY=50,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vBE9-bB9yVMO"
},
"source": [
"**Group task** <font color='green'>`Base`</font> - Discuss with a neighbour and / or tutor:\n",
"\n",
"* Under what conditions does our variational inference algorithm recover the true posterior, i.e. $q(\\color{blue}{\\underline{w}})= p(\\color{blue}{\\underline{w}} | X, \\underline{y})$? Did we recover the exact posterior here?\n",
"\n",
"* Do you think we can overfit with this variational inference algorithm? Why?\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eoJmOtd2EJMN"
},
"source": [
"## Lets approximate the Logistic regression posterior with variational inference\n",
"\n",
"Now that we have verified that our black box variational inference algorithm does the right thing, lets try to apply it for the logistic regression problem.\n",
"\n",
"Remember from [Section 2](#S2) that in logistic regression we are targettting a non-Gaussian posterior that is proportional to \n",
"\n",
"<center>\n",
"$p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) = \\prod_{n=1}^N Bern(y_n ; \\phi(x_n \\color{blue}{\\underline{w}})) \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)$\n",
"</center>\n",
"\n",
"Lets run our same variational inference algorithm and see how it performs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "JiHDAtE7ymKH"
},
"outputs": [],
"source": [
"# @title pre-compute loss surface and make plotting function -- run this cell but there is no need to look at the code\n",
"\n",
"alpha = 0.75\n",
"\n",
"\n",
"def generate_loss_fun():\n",
" def loss_fun(w):\n",
" return logistic_regression_loss(\n",
" X=affine_basis(x_data_list_classification),\n",
" y=y_data_list_classification,\n",
" vmap_model=vmap_linear,\n",
" w=w,\n",
" alpha=alpha,\n",
" )\n",
"\n",
" return loss_fun\n",
"\n",
"\n",
"loss_fun = generate_loss_fun()\n",
"\n",
"x0_grid, x1_grid, loss_grid = generate_loss_grid(\n",
" loss_fun=loss_fun, grid_size=200, lim0=[-8, 8], lim1=[-6, 10]\n",
")\n",
"\n",
"\n",
"def gen_plots(params):\n",
"\n",
" # Plot loss landscape and variational distribution\n",
" q_covariance = get_Sig(params, min_diag=1e-6)\n",
" q_mu = params[\"w\"]\n",
"\n",
" x_pred = np.linspace(-4, 4, 100)\n",
" X_pred = affine_basis(x_pred)\n",
"\n",
" Nsamples = 20\n",
" q_parameter_samples = jax.random.multivariate_normal(\n",
" key, q_mu, q_covariance, shape=(Nsamples,)\n",
" )\n",
" q_sample_preds = vmap_linear(X_pred, q_parameter_samples.T)\n",
" q_prob_samples = jax.nn.sigmoid(q_sample_preds)\n",
"\n",
" fig, ax = plt.subplots(1, 2, dpi=160, figsize=(10, 4))\n",
"\n",
" loss_landscape = ax[0].pcolormesh(\n",
" x0_grid,\n",
" x1_grid,\n",
" -loss_grid,\n",
" vmin=-loss_grid.min() * 10,\n",
" vmax=-loss_grid.min(),\n",
" cmap=\"viridis\",\n",
" ) #\n",
"\n",
" plot_log_gaussian_ellipse(\n",
" ax=ax[0],\n",
" mean=q_mu,\n",
" cov=q_covariance,\n",
" color=\"cyan\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"variational posterior\",\n",
" MAP_size=25,\n",
" std_levels=[1, 2, 4, 6],\n",
" )\n",
"\n",
" ax[0].scatter(\n",
" q_parameter_samples[:, 0],\n",
" q_parameter_samples[:, 1],\n",
" 5,\n",
" c=\"cyan\",\n",
" label=\"samples\",\n",
" )\n",
" cbar = fig.colorbar(loss_landscape, ax=ax[0])\n",
" cbar.set_label(\"Regularised Loss\", rotation=270, labelpad=20)\n",
" ax[0].set_ylabel(\"w\")\n",
" ax[0].set_xlabel(\"b\")\n",
" ax[0].grid(alpha=0.3)\n",
" ax[0].legend()\n",
" ax[0].set_ylim([-6, 10])\n",
" ax[0].set_xlim([-8, 8])\n",
" ax[0].set_title(\"evaluating variational Bayesian posterior \\n in weight space\")\n",
"\n",
" for i, sample in enumerate(q_prob_samples.T):\n",
" if i == 0:\n",
" ax[1].plot(\n",
" x_pred,\n",
" sample,\n",
" \"-\",\n",
" color=\"cyan\",\n",
" lw=0.4,\n",
" alpha=0.4,\n",
" label=\"sample class probability\",\n",
" )\n",
" ax[1].plot(x_pred, sample, \"-\", color=\"cyan\", lw=0.4, alpha=0.4)\n",
"\n",
" ax[1].plot(\n",
" x_pred,\n",
" q_prob_samples.mean(axis=1),\n",
" \"--\",\n",
" color=\"cyan\",\n",
" lw=2,\n",
" label=\"variational posterior class probabilities\",\n",
" )\n",
"\n",
" mode_predictions = sigmoid(\n",
" vmap_linear(affine_basis(x_pred), optimised_regularised_classification_params)\n",
" )\n",
" ax[1].plot(x_pred, mode_predictions, c=\"red\", label=\"optimised class probabilities\")\n",
"\n",
" # Plot posterior predictive\n",
"\n",
" ax[1].plot(x_data_list_classification, y_data_list_classification, \"ob\")\n",
" ax[1].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=xlim, ylim=[-0.1, 1.1])\n",
" ax[1].grid(alpha=0.3)\n",
" ax[1].legend()\n",
" ax[1].set_title(\n",
" \"comparing regularised solution and \\n variational Bayesian solution in function space\"\n",
" )\n",
"\n",
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "PFVX1IljECfP"
},
"outputs": [],
"source": [
"# @title Run me to train variational distribution on logistic regression task\n",
"\n",
"\n",
"Gaussian_ELBO = gen_Gaussian_ELBO(\n",
" model=vmap_linear, alpha=alpha, ll_func=bernouilli_ll, Nsamples=100\n",
") # Nottice we use the bernoulli likelihood now\n",
"\n",
"\n",
"params = initialize_params(D=2, key=random.PRNGKey(0))\n",
"\n",
"objective = jit(\n",
" partial(\n",
" Gaussian_ELBO,\n",
" X=affine_basis(x_data_list_classification),\n",
" y=y_data_list_classification,\n",
" )\n",
")\n",
"\n",
"\n",
"optimised_regularised_params = optimise(\n",
" objective,\n",
" params=params,\n",
" plotting_func=gen_plots,\n",
" LR=5e-3,\n",
" MAX_STEPS=2000,\n",
" LOG_EVERY=50,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "46aAuXhXkOcY"
},
"source": [
"**Group task** <font color='green'>`Base`</font> - Discuss with a neighbour and / or tutor:\n",
"\n",
"* How many parameters are we learning when we perform variational inference?\n",
"\n",
"* Are we reaching exact posterior in the logistic regression case?\n",
"\n",
"* What are potential drawbacks of our variational inference algorithm?\n",
"\n",
"* How does the variational posterior predictive class probability compare with the regularised optimisation solution? Is it more or less confident?\n",
"\n",
"* Do all the samples represent reasonable solutions? If not, why?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y6pNsO1oJ9dV"
},
"source": [
"## Takeaways\n",
"\n",
"* We have implemented black box variational inference using a Gaussian variationa posterior. This algorithm allows us to **approximate any model's posterior** distribution with a Gaussian variational posterior, thus the name ''black box''.\n",
"\n",
"* When the true posterior is Gaussian, our variational distribution will recover it exactly. When the true posterior is not Gaussian, then our variational distribution will choose the closest Gaussian in a KL divergence sense.\n",
"\n",
"* Since our optimisation is minimising the KL divergence to the true posterior, optimising more can only get us closer to the true posterior. We can not overfit with variational inference. \n",
"\n",
"## End of Section\n",
"\n",
"The rest of the section contains more advanced optional contents. You can skip to [Section 4](#S4) if you want."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pxh4QTJ2G-7T"
},
"source": [
"**Group task** (Optional) <font color='orange'>`Intermediate`</font> - Discuss with a neighbour and / or tutor:\n",
"\n",
"Remember that in the linear regression case we estimated both a predictive mean and predictive standard deviation from samples. In logistic regression, each sample provides us with a curve of class probability. We have computed the variational predictive mean class probability as the average of these probabilities and shown it with the dashed cyan line.\n",
"\n",
"* Does it make sense to return standard deviation over samples like we did in the regression case? \n",
"\n",
"* How should we combine information from samples in logistic regression?\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HmacEbJ0Zw9M"
},
"source": [
"# Section 4: Bayesian Neural Networks <a name=\"S4\"></a>\n",
"\n",
"Now that we have our black-box VI algorithm coded up, its time to appy it to neural networks. We are going to see that when switching from linear and logistic regression to deep models, things become complicated fast.\n",
"\n",
"This is the final core section of this prac. Spend as much time as you want on it but keep in mind that there is an **additional optional section 5 at the end where you can learn about Hamiltonian Monte Carlo**."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ZlVqE2wMIxCq"
},
"outputs": [],
"source": [
"# @title Lets start with a very simple dataset, one that is very close to the indentity function $\\,\\,y=x$\n",
"\n",
"x_data_list_identity = jnp.array([-2, -1.8, -1, 2, 2, 2.1])\n",
"y_data_list_identity = jnp.array([-2, -1.9, -1, 2, 1.9, 2.1])\n",
"\n",
"plot_basic_data(x_data_list_identity, y_data_list_identity, ylim=[-3, 3])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4ed8Wt0AOJnf"
},
"outputs": [],
"source": [
"# @title Lets start with a very basic NN\n",
"# @markdown Specifically, a 1 hidden layer, & 1 hidden unit NN with no non-linearity: $f(x, \\color{blue}{\\underline{w}}) = \\color{blue}{\\underline{w}_{1}} \\color{blue}{\\underline{w}_{0}} x$\n",
"\n",
"# @markdown This model is linear in the inputs and weights.\n",
"\n",
"\n",
"def NN(x, params):\n",
" return params[1] * params[0] * x\n",
"\n",
"\n",
"vmap_NN = jit(vmap(NN, in_axes=(0, None)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "9s1cnEsgZ-9-"
},
"outputs": [],
"source": [
"# @title ## Plot the loss landscape. It is non-convex. Its multimodal!\n",
"# @markdown Even though we are trying to learn a simple identity function with a linear NN.\n",
"\n",
"# @markdown This is due to the symmetries in the weight space between $\\color{blue}{\\underline{w}_{1}}$ and $\\color{blue}{\\underline{w}_{0}}$. You can flip these weights around and obtain the same solution!\n",
"\n",
"\n",
"alpha = 1\n",
"\n",
"\n",
"def generate_loss_fun():\n",
" def loss_fun(w):\n",
" return regularised_linear_regression_loss(\n",
" X=x_data_list_identity,\n",
" y=y_data_list_identity,\n",
" w=w,\n",
" vmap_model=vmap_NN,\n",
" alpha=alpha,\n",
" )\n",
"\n",
" return loss_fun\n",
"\n",
"\n",
"loss_fun = generate_loss_fun()\n",
"\n",
"x0_grid, x1_grid, loss_grid = generate_loss_grid(\n",
" loss_fun=loss_fun, grid_size=300, lim0=[-5, 5], lim1=[-5, 5]\n",
")\n",
"\n",
"fig = plt.figure(dpi=130)\n",
"loss_landscape = plt.pcolormesh(\n",
" x0_grid,\n",
" x1_grid,\n",
" -loss_grid,\n",
" vmin=-loss_grid.min() * 5,\n",
" vmax=-loss_grid.min(),\n",
" cmap=\"viridis\",\n",
") #\n",
"\n",
"plt.xlabel(\"w0\")\n",
"plt.ylabel(\"w1\")\n",
"ax = [plt.gca()]\n",
"cbar = fig.colorbar(loss_landscape, ax=ax[0])\n",
"cbar.set_label(\"Regularised Loss\", rotation=270, labelpad=20)\n",
"plt.title(\"Simple linear NN loss landsape\")\n",
"plt.grid(alpha=0.3)"
]
},
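{
"cell_type": "markdown",
"metadata": {},
"source": [
"The weight-space symmetry described above can be checked directly. The sketch below re-defines a model of the same form as `NN` (the name `linear_nn` is ours, used only for illustration) and evaluates it at one weight setting, at the same setting with $w_0$ and $w_1$ swapped, and with the signs of both weights flipped.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def linear_nn(x, params):\n",
"    # Same form as NN above: f(x, w) = w1 * w0 * x\n",
"    return params[1] * params[0] * x\n",
"\n",
"\n",
"x = jnp.linspace(-2.0, 2.0, 5)\n",
"w = jnp.array([0.5, 3.0])\n",
"f_orig = linear_nn(x, w)\n",
"f_swapped = linear_nn(x, w[::-1])  # swap w0 and w1\n",
"f_flipped = linear_nn(x, -w)  # flip the sign of both weights\n",
"```\n",
"\n",
"All three give identical predictions, and since the Gaussian prior only depends on the norm of the weights, the regularised loss takes the same value at all three points. This is why the loss landscape contains mirrored copies of each good solution."
]
},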
{
"cell_type": "markdown",
"metadata": {
"id": "T4p0hCdqbhYV"
},
"source": [
"## Our unimodal Gaussian approximation will struggle with the multimodal posterior but lets apply the black box variational inference algorithm anyway!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "rKo7u2htToma"
},
"outputs": [],
"source": [
"# @title define plotting method -- run this cell but there is no need to look at the code\n",
"\n",
"\n",
"def gen_plots(params):\n",
" # Plot loss landscape and variational distribution\n",
" q_covariance = get_Sig(params, min_diag=1e-6)\n",
" q_mu = params[\"w\"]\n",
"\n",
" fig, ax = plt.subplots(1, 2, dpi=160, figsize=(10, 4))\n",
"\n",
" loss_landscape = ax[0].pcolormesh(\n",
" x0_grid,\n",
" x1_grid,\n",
" -loss_grid,\n",
" vmin=-loss_grid.min() * 5,\n",
" vmax=-loss_grid.min(),\n",
" cmap=\"viridis\",\n",
" ) #\n",
" plot_log_gaussian_ellipse(\n",
" ax=ax[0],\n",
" mean=q_mu,\n",
" cov=q_covariance,\n",
" color=\"cyan\",\n",
" alpha=1,\n",
" lw=1,\n",
" label=\"variational posterior\",\n",
" MAP_size=25,\n",
" std_levels=[1, 2, 4, 6],\n",
" )\n",
" ax[0].set_ylabel(\"w\")\n",
" ax[0].set_xlabel(\"b\")\n",
" ax[0].grid(alpha=0.3)\n",
" ax[0].legend()\n",
" ax[0].set_ylim([-5, 5])\n",
" ax[0].set_xlim([-5, 5])\n",
" cbar = fig.colorbar(loss_landscape, ax=ax[0])\n",
" cbar.set_label(\"Regularised Loss\", rotation=270, labelpad=20)\n",
" ax[0].set_title(\"variational distribution fit to true posterior\")\n",
"\n",
" # Plot posterior predictive\n",
"\n",
" x_pred = np.linspace(-5, 5, 100)\n",
" Nsamples = 20\n",
" q_parameter_samples = jax.random.multivariate_normal(\n",
" key, q_mu, q_covariance, shape=(Nsamples,)\n",
" )\n",
" q_sample_preds = vmap_NN(x_pred, q_parameter_samples.T)\n",
"\n",
" for i, sample in enumerate(q_sample_preds.T):\n",
" if i == 0:\n",
" ax[1].plot(\n",
" x_pred,\n",
" sample,\n",
" \"-\",\n",
" color=\"cyan\",\n",
" lw=0.4,\n",
" alpha=1,\n",
" label=\"sample predictions\",\n",
" )\n",
" ax[1].plot(x_pred, sample, \"-\", color=\"cyan\", lw=0.4, alpha=1)\n",
"\n",
" errorfill(\n",
" x_pred,\n",
" q_sample_preds.mean(axis=1),\n",
" q_sample_preds.std(axis=1),\n",
" color=\"cyan\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax[1],\n",
" lw=2.5,\n",
" linestyle=\"--\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"preditive mean\",\n",
" markevery=None,\n",
" )\n",
"\n",
" ax[1].plot(x_data_list_identity, y_data_list_identity, \"ob\")\n",
" ax[1].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=[-3, 3], ylim=[-4, 4])\n",
" ax[1].grid(alpha=0.3)\n",
" ax[1].set_title(\"Predictive posterior fit\")\n",
" ax[1].legend()\n",
"\n",
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8oUNOkcHRg1f"
},
"outputs": [],
"source": [
"# @title **Plotting task** <font color='geen'>`Base`</font> Try setting `large_init` to both true and false and optimising the variational parameters.\n",
"# @markdown What changes?\n",
"\n",
"large_init = True # @param ['True', 'False'] {type:\"raw\"}\n",
"\n",
"Gaussian_ELBO = gen_Gaussian_ELBO(\n",
" model=vmap_NN, alpha=alpha, ll_func=gaussian_ll, Nsamples=100\n",
")\n",
"\n",
"\n",
"params = initialize_params(D=2, key=random.PRNGKey(42))\n",
"if large_init:\n",
" params[\"w\"] = params[\"w\"] * 5\n",
"\n",
"\n",
"objective = jit(partial(Gaussian_ELBO, X=x_data_list_identity, y=y_data_list_identity))\n",
"\n",
"\n",
"optimise(\n",
" objective,\n",
" params=params,\n",
" plotting_func=gen_plots,\n",
" LR=5e-3,\n",
" MAX_STEPS=2000,\n",
" LOG_EVERY=50,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ugV_8P7eUY1v"
},
"source": [
"\n",
"\n",
"**Group task <font color='green'>`Base`</font>** - Discuss with a neighbour and / or tutor:\n",
"\n",
"* What parameter settings, i.e. values of $\\color{blue}{\\underline{w}_{1}}$ & $\\color{blue}{\\underline{w}_{0}}$ lead usto recover the identity function?\n",
"\n",
"* Has your experimentation with the `large_init` option led you to discover a potential failure mode of our black box variational inference algorithm?\n",
"\n",
"* When we do converge to one of the modes, does it matter which one we converge to? Are both the same?\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e3JzXp74Z_po"
},
"source": [
"## Exploring the limitations of Variational Inference for BNNs \n",
"\n",
"\n",
"Now that we have seen a potential issue that can stem from multimodal loss landscapes, lets try a slightly more realistic example: a 4 parameter neural network complete with weights and biases. We will try to fit the same identity dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1aLad3ceagHh"
},
"outputs": [],
"source": [
"# @title Our NN is now $f(x, \\color{blue}{\\underline{w}}) = \\color{blue}{w_1} \\text{ReLU}(\\color{blue}{w_0} x + \\color{blue}{b_0}) + \\color{blue}{b_1}$\n",
"\n",
"\n",
"def NN(x, params):\n",
" return params[2] * jax.nn.relu(x * params[0] + params[1]) + params[3]\n",
"\n",
"\n",
"vmap_NN = jit(vmap(NN, in_axes=(0, None)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "-UdB8aKXmf5f"
},
"outputs": [],
"source": [
"# @title plot generation code -- run this cell but there is no need to look at the code\n",
"\n",
"\n",
"def gen_plots(params):\n",
" # Plot loss landscape and variational distribution\n",
" q_covariance = get_Sig(params, min_diag=1e-6)\n",
" q_MAP = params[\"w\"]\n",
"\n",
" # Plot posterior predictive\n",
"\n",
" x_pred = np.linspace(-5, 5, 100)\n",
" Nsamples = 20\n",
" q_parameter_samples = jax.random.multivariate_normal(\n",
" key, q_MAP, q_covariance, shape=(Nsamples,)\n",
" )\n",
" q_sample_preds = vmap_NN(x_pred, q_parameter_samples.T)\n",
"\n",
" plt.figure(dpi=120)\n",
"\n",
" ax = plt.gca()\n",
"\n",
" for i, sample in enumerate(q_sample_preds.T):\n",
" if i == 0:\n",
" ax.plot(\n",
" x_pred,\n",
" sample,\n",
" \"-\",\n",
" color=\"cyan\",\n",
" lw=0.4,\n",
" alpha=1,\n",
" label=\"sample predictions\",\n",
" )\n",
" ax.plot(x_pred, sample, \"-\", color=\"cyan\", lw=0.4, alpha=1)\n",
"\n",
" errorfill(\n",
" x_pred,\n",
" q_sample_preds.mean(axis=1),\n",
" q_sample_preds.std(axis=1),\n",
" color=\"cyan\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax,\n",
" lw=2.5,\n",
" linestyle=\"--\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"preditive mean\",\n",
" markevery=None,\n",
" )\n",
"\n",
" ax.plot(x_data_list_identity, x_data_list_identity, \"ob\")\n",
" ax.set_title(\"Predictive posterior fit\")\n",
" ax.set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=[-3, 3], ylim=[-4, 4])\n",
" plt.grid(alpha=0.3)\n",
" plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "SFzQ6da2agCT"
},
"outputs": [],
"source": [
"# @title **Plotting task** <font color='green'>`Base`</font>: Optimise the variational parameters with different seeds\n",
"# @markdown We cant plot the loss landscape anymore because it is 4 dimensional. But we can always plot the predictive posterior\n",
"\n",
"# @markdown Try using different seeds: 0 and 42\n",
"\n",
"alpha = 0.1\n",
"\n",
"Gaussian_ELBO = gen_Gaussian_ELBO(\n",
" model=vmap_NN, alpha=alpha, ll_func=gaussian_ll, Nsamples=100\n",
")\n",
"\n",
"Seed = 43 # @param {type:\"integer\"}\n",
"\n",
"params = initialize_params(D=4, key=random.PRNGKey(Seed))\n",
"\n",
"\n",
"objective = jit(partial(Gaussian_ELBO, X=x_data_list_identity, y=y_data_list_identity))\n",
"\n",
"optimise(\n",
" objective,\n",
" params=params,\n",
" plotting_func=gen_plots,\n",
" LR=2e-3,\n",
" MAX_STEPS=3000,\n",
" LOG_EVERY=50,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9K233sZpWruk"
},
"source": [
"\n",
"**Group task <font color='green'>`Base`</font>**: Discuss with a neighbour and / or tutor:\n",
"\n",
"* How do the solutions found by the seeds 0 and 42 differ? What does this tell you about the loss landscape?\n",
"\n",
"* Can you think up with a setting of the parameters ($b_0, w_0, b_1, w_1$) that recovers the identity?\n",
"\n",
"* Why doesnt our optimiser find this configuration?\n",
"\n",
"* Do the errorbars look reasonable?\n",
"\n",
"* Do you think that our black box variational inference algorithm is a good tool for use with neural network models? How could it be improved? \n",
" * Hint: what if we averaged the solutions obtained with seeds 0 and 42?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5uOjjqT7pV41"
},
"source": [
"## Takeaways\n",
"\n",
"* Deep architectures result in multimodal loss functions and posteriors. This poses a challenge for our black box variational inference algorithm.\n",
"\n",
"* Potential issues that may arise are: \n",
" 1. the optimiser getting stuck in a local optima between modes where a lot of variational posterior mass is placed in high loss regions and thus we obtain a poor data fit.\n",
" 2. the optimiser identifying a mode of the loss that corresponds to an overly complicated solution that will not generalise: like fitting our \"identity\" dataset with an angled line.\n",
"\n",
"## -- End of Prac --\n",
"\n",
"The [next section](#S5) contains more advanced optional contents. You can skip to the [conclusions](#C) if you want."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I_i353LzETrQ"
},
"source": [
"# Section 5: Hamiltonian Monte Carlo (Optional) - <font color='orange'>`Intermediate`</font> <a name=\"S5\"></a>\n",
"\n",
"In the previous section, our Gaussian variational distribution was only able to identify individual modes of the posterior distribution and these often corresponded to poor predictions or poor errorbars.\n",
"\n",
"Now, we will look at an approach for approximate Bayesian inference which is capable of working well even with multimodal posteriors: Hamiltonian Monte Carlo (HMC). This sampling method was introduced in [Radford Neal's thesis](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.446.9306&rep=rep1&type=pdf). HMC is designed to try to reach as many modes of the posterior as possible, giving an accurate approximation even for deep networks.\n",
"We will not cover the math behind this method in detail but we will explain how to use it. \n",
"\n",
"To use HMC, we just require access to an unnormalised posterior density, i.e. a loss function. As we saw in [Section 1](#S1), we already have this in the form of our regularised least squares objective: \n",
"\n",
"<center>\n",
"$\\log \\left( p(\\underline{y} | X, \\color{blue}{\\underline{w}}) p(\\color{blue}{\\underline{w}}) \\right) = \\log \\left( \\mathcal{N}(\\underline{y}; X\\cdot \\color{blue}{\\underline{w}}, I) \\mathcal{N}(\\color{blue}{\\underline{w}}; 0, \\alpha^{-1} I)\\right)= \\frac{-1}{2}||\\underline{y} - X\\cdot \\color{blue}{\\underline{w}}||_{2}^2 + \\frac{-\\alpha}{2} ||\\color{blue}{\\underline{w}}||^2_2 + C$\n",
"</center>\n",
"\n",
"\n",
"\n",
"As a Monte Carlo method, HMC will provide us with samples from the posterior $p(\\color{blue}{\\underline{w}} | X, \\underline{y})$. We can then use these samples to estimate quantities of interest. In our case this will be the predictive mean and predictive standard deviation.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3yn2ybsPotYd"
},
"source": [
"We are going to use [NumPyro](https://github.com/pyro-ppl/numpyro) an HMC library that works with jax."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bux8haYJSdGn"
},
"outputs": [],
"source": [
"# @title Here we define our numpyro model\n",
"\n",
"# @markdown We specify a priors over our weights n with `numpyro.sample(\"w\"...)` and the likelihood function `numpyro.sample(name=\"err\"..., obs=err)`. The `obs` parameter states that we observe the targets, making this a likelihood funciton.\n",
"\n",
"# @markdown Although HMC can work with the loss function directly, and numpyro supports this, it is more natural to specify our model in terms of prior and likelihood in numpyro.\n",
"\n",
"import numpyro\n",
"from numpyro.infer import MCMC, NUTS, init_to_value\n",
"import numpyro.distributions as dist\n",
"\n",
"\n",
"def NN_numpyro_model(y, X, D, model, alpha):\n",
" w = numpyro.sample(\"w\", dist.Normal(jnp.zeros(D), (alpha**-0.5) * jnp.ones(D)))\n",
" preds = model(X, w)\n",
" err = y - preds\n",
" numpyro.sample(name=\"err\", fn=dist.Normal(0, jnp.ones(len(y))), obs=err)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_pinWVMQzzxp"
},
"outputs": [],
"source": [
"# @title Code to draw samples -- We use a standard configuraition\n",
"\n",
"\n",
"def draw_samples(\n",
" model, warmup, n_samples, thinning, num_chains, init_dict={}, max_tree_depth=12\n",
"):\n",
"\n",
" nuts_kernel = NUTS(\n",
" model,\n",
" forward_mode_differentiation=False,\n",
" max_tree_depth=max_tree_depth,\n",
" init_strategy=init_to_value(values=init_dict),\n",
" )\n",
" mcmc = MCMC(\n",
" nuts_kernel,\n",
" num_warmup=warmup,\n",
" num_samples=n_samples,\n",
" thinning=thinning,\n",
" num_chains=num_chains,\n",
" )\n",
" rng_key = random.PRNGKey(0)\n",
"\n",
" mcmc.run(\n",
" rng_key,\n",
" )\n",
" mcmc.print_summary()\n",
" return mcmc.get_samples(group_by_chain=False)\n",
"\n",
"\n",
"alpha = 0.05\n",
"\n",
"sampling_model = partial(\n",
" NN_numpyro_model,\n",
" y=y_data_list_identity,\n",
" X=x_data_list_identity,\n",
" D=4,\n",
" model=vmap_NN,\n",
" alpha=alpha,\n",
")\n",
"samples = draw_samples(\n",
" sampling_model, warmup=5000, n_samples=5000, thinning=10, num_chains=1\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "MXR7i1NxyZDv"
},
"outputs": [],
"source": [
"# @title Lets see the results\n",
"\n",
"# @markdown We should get a much smoother regression line than with VI\n",
"\n",
"\n",
"sample_preds = vmap_NN(x_pred, samples[\"w\"].T)\n",
"\n",
"fig, ax = plt.subplots(1, 2, dpi=120, figsize=(9, 4))\n",
"\n",
"\n",
"for i, sample in enumerate(sample_preds.T[::20]):\n",
" if i == 0:\n",
" ax[0].plot(\n",
" x_pred,\n",
" sample,\n",
" \"-\",\n",
" color=\"cyan\",\n",
" lw=0.4,\n",
" alpha=1,\n",
" label=\"sample predictions\",\n",
" )\n",
" ax[0].plot(x_pred, sample, \"-\", color=\"cyan\", lw=0.4, alpha=1)\n",
"\n",
"\n",
"ax[0].plot(x_data_list_identity, x_data_list_identity, \"ob\")\n",
"ax[0].set_title(\"sample predictions\")\n",
"ax[0].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=[-3, 3], ylim=[-4, 4])\n",
"ax[0].grid(alpha=0.3)\n",
"\n",
"errorfill(\n",
" x_pred,\n",
" sample_preds.mean(axis=1),\n",
" sample_preds.std(axis=1),\n",
" color=\"cyan\",\n",
" alpha_fill=0.2,\n",
" line_alpha=1,\n",
" ax=ax[1],\n",
" lw=2.5,\n",
" linestyle=\"--\",\n",
" fill_linewidths=0.2,\n",
" marker=None,\n",
" markersize=1,\n",
" label=\"preditive mean\",\n",
" markevery=None,\n",
")\n",
"\n",
"\n",
"ax[1].plot(x_data_list_identity, x_data_list_identity, \"ob\")\n",
"ax[1].set_title(\"Predictive mean and standard deviation errorbars\")\n",
"ax[1].set(xlabel=\"Input x\", ylabel=\"Output y\", xlim=[-3, 3], ylim=[-4, 4])\n",
"ax[1].grid(alpha=0.3)\n",
"# plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YBagBCgk2ERY"
},
"source": [
"\n",
"**Group task <font color='orange'>`Intermediate`</font>**: Discuss with a neighbour and / or tutor:\n",
"\n",
"* From the sample predictions, which modes has the sampler found? Has the sampler found all the modes we found with VI? Has the sampler found any new modes that VI failed to find?\n",
"\n",
"* Do the errorbars look more or less reasonable than those obtained from our VI algorithm?\n",
"\n",
"* How well do you think this method will scale to neural networks with more parameters, i.e. higher dimensional weight spaces?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AQDjHgGI2xlS"
},
"source": [
"## Takeaways\n",
"\n",
"* HMC is a very powerful method that currently represents the state-of-the art for inference accuracy in neural networks. It has usage beyond inference in NNs however. It is a great tool to apply to any problem where we need to sample from multimodal unormalised distributions.\n",
"\n",
"* However HMC is not perfect. It has serious limitations:\n",
" 1. It is not guaranteed to find every posterior mode. For instance, our HMC sampler failed to find the mode in which the function stays flat from -4 to 0 and then turns upward. This limitation gets worse in higher dimensions where different modes can be very far apart and are thus very hard to find.\n",
" 2. HMC is expensive to run. Specifically, HMC is a full batch method. For large datasets, each HMC step requires a full pass through the dataset. A standard HMC sampler delivers around 1 sample per 10 steps. This means that we need to make thousands of passes through our full dataset to get a couple hundred samples. This is not possible in modern settings. VI can be minibatched and thus does not suffer from this limitation.\n",
"\n",
"* HMC is a more accurate but more expensive alternative to VI. Keep this in mind when choosing which algorithm to use for each task.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fV3YG7QOZD-B"
},
"source": [
"# Prac Conclusion <a name=\"C\"></a>\n",
"**Summary:**\n",
"\n",
"* Bayesian inference differs from traditional loss minimisation learning in that instead of retuning a single best parameter estimate, Bayesian inference returns every possible parameter setting weighed by how well these agree with the prior and the data.\n",
"\n",
"* The disagreement among plausible parameter settings induces uncertainty in the predictions.\n",
"\n",
"* Bayesian inference is hard because it requires characterising how well all parameter settings agree with the prior and data, i.e. characterising the posterior distribution. This is only possible to do analyticially in a very reduced range of models, such as the linear-Gaussian case.\n",
"\n",
"* When the exact posterior is not tractable, we have to resort to approximations, such as variational inference or HMC. \n",
"\n",
"* Approximations present a tradeoff between fidelity and computational cost. VI is often cheaper but less accurate while HMC is expensive but accurate.\n",
"\n",
"**Next Steps:** \n",
"\n",
"If you enjoyed this prac and want to learn more about Bayesian inference, a great resource is [Pattern Recognition and\n",
"Machine Learning](http://users.isr.ist.utl.pt/~wurmd/Livros/school/Bishop%20-%20Pattern%20Recognition%20And%20Machine%20Learning%20-%20Springer%20%202006.pdf) by Chris Bishop. Specifically, Chapters 2 to 5 cover the contents of this prac in more detail: probability distributions (2), linear regression (3), logistic regression (4) and neural networks (5).\n",
"\n",
"Here are some keystone papers from the field of Bayesian Deep Learning:\n",
"\n",
"* [David Mackay's thesis](https://www.inference.org.uk/mackay/thesis.pdf) -- Introduces the linearised Laplace approximation\n",
"* [Radford Neal's thesis](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.446.9306&rep=rep1&type=pdf) -- introduces HMC\n",
"* [Probabilistic Backpropagation](https://arxiv.org/abs/1502.05336) -- This paper employs a form of variational inference called expectation propagation\n",
"* [Bayes by Backprop](https://arxiv.org/abs/1505.05424) -- this one implements the black box VI algorithm that we coded in [Section 3](#S3)\n",
"* [Black-box α-divergence Minimization](https://arxiv.org/abs/1511.03243) -- This paper employs a form of variational inference called alpha divergence minimisation\n",
"* [Deep Ensembles](https://arxiv.org/abs/1612.01474?context=cs) -- this method does not perform Bayesian inference but it uses a lot of the same techniques as Bayesian inference and it performs very well.\n",
"\n",
"Currently, the best performing Bayesian Deep Learning method that scales to modern neural networks is [modernised Linearised Laplace](https://arxiv.org/abs/2008.08400). Apart from providing accurate errorbars, this method computes an [estimate of the model evidence for the neural network which can be used to tune hyperparameters](https://arxiv.org/abs/2206.08900).\n",
"\n",
"\n",
"\n",
"-----\n",
"For other practicals from the Deep Learning Indaba, please visit the [indaba-pracs-2022 repository](https://github.com/deep-learning-indaba/indaba-pracs-2022)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o1ndpYE50BpG"
},
"source": [
"## Feedback\n",
"\n",
"Please provide feedback that we can use to improve our practicals in the future."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OIZvkhfRz9Jz"
},
"outputs": [],
"source": [
"# @title Generate Feedback Form. (Run Cell)\n",
"from IPython.display import HTML\n",
"\n",
"# Embed the feedback form in an iframe so it can be filled in inside the notebook.\n",
"HTML(\n",
" \"\"\"\n",
"<iframe\n",
"\tsrc=\"https://forms.gle/bvLLPX74LMGrFefo9\"\n",
"\twidth=\"80%\"\n",
"\theight=\"1200px\">\n",
"\tLoading...\n",
"</iframe>\n",
"\"\"\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oglV4kHMWnIN"
},
"source": [
"<img src=\"https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png\" width=\"50%\" />"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Updated_Bayesian_Deep_Learning_Prac.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: practicals/GNN_practical.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "obBt50HcYEC_"
},
"source": [
"# Graph Neural Networks\n",
"\n",
"<img src=\"https://miro.medium.com/max/1400/1*ZMIwWfua5zG_C4hcN_4I_g.png\" width=\"60%\" />\n",
"\n",
"<a href=\"https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/practicals/GNN_practical.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
"\n",
"© Deep Learning Indaba 2022. Apache License 2.0.\n",
"\n",
"**Authors:**\n",
"Matthew Morris, Tom Makkink, and Jama Hussein Mohamud.\n",
"\n",
"Credit to Lisa Wang and Nikola Jovanović for providing much of the basic [content](https://github.com/deepmind/educational/blob/master/colabs/summer_schools/intro_to_graph_nets_tutorial_with_jraph.ipynb).\n",
"\n",
"**Introduction:** \n",
"\n",