Full Code of cloneofsimo/d3pm for AI

main 3ceb63725ee2 cached
10 files
215.3 KB
112.3k tokens
62 symbols
1 requests
Download .txt
Showing preview only (222K chars total). Download the full file or copy to clipboard to get everything.
Repository: cloneofsimo/d3pm
Branch: main
Commit: 3ceb63725ee2
Files: 10
Total size: 215.3 KB

Directory structure:
gitextract_jcwmbccq/

├── .gitignore
├── CITATION.cff
├── d3pm_runner.py
├── d3pm_runner_cifar10.py
├── dit.py
├── lm.py
├── lm_deepspeed.py
├── readme.md
├── run_multigpu.sh
└── test.ipynb

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
presetup.sh
setup.sh
cu122py310
data
test
wandb
run.sh
*.pyc
ckpt

================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "Citations would be appreciated if you end up using this tool! I currently go by Simo Ryu"
authors:
- family-names: "Ryu"
  given-names: "Simo"
  orcid: "https://orcid.org/0009-0008-0017-2677"
title: "Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch"
version: 0.0.1
date-released: 2024-04
url: "https://github.com/cloneofsimo/d3pm"

================================================
FILE: d3pm_runner.py
================================================
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm import tqdm

def blk(ic, oc):
    """Build a conv block: 3x (5x5 Conv -> GroupNorm -> LeakyReLU), ic -> oc channels.

    Defined as a ``def`` rather than a lambda assignment (PEP 8 E731) so it
    has a proper name in tracebacks and can carry this docstring. Spatial
    size is preserved (kernel 5, padding 2). ``oc`` must be divisible by 8
    for the GroupNorm group count.
    """
    return nn.Sequential(
        nn.Conv2d(ic, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
    )

def blku(ic, oc):
    """Build an upsampling conv block: 3x (5x5 Conv -> GroupNorm -> LeakyReLU)
    followed by a 2x ConvTranspose upsample (+ GroupNorm + LeakyReLU).

    Defined as a ``def`` rather than a lambda assignment (PEP 8 E731).
    Doubles the spatial size; ``oc`` must be divisible by 8 for GroupNorm.
    """
    return nn.Sequential(
        nn.Conv2d(ic, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.ConvTranspose2d(oc, oc, 2, stride=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
    )


class DummyX0Model(nn.Module):
    """Small UNet-style network predicting x_0 logits for D3PM on MNIST.

    Takes a discretized image ``x`` with integer states in [0, N-1], a
    timestep ``t``, and a class label ``cond`` in [0, 9], and returns logits
    over the N states for every pixel, shaped (B, C, H, W, N).
    Assumes 32x32 spatial input so the four pooling stages line up with the
    skip connection below -- TODO confirm for other resolutions.
    """

    def __init__(self, n_channel: int, N: int = 16) -> None:
        super(DummyX0Model, self).__init__()
        # Encoder blocks; forward() inserts 2x average pooling between them.
        self.down1 = blk(n_channel, 16)
        self.down2 = blk(16, 32)
        self.down3 = blk(32, 64)
        self.down4 = blk(64, 512)
        self.down5 = blk(512, 512)
        # Decoder blocks (blku ends with a 2x ConvTranspose upsample).
        self.up1 = blku(512, 512)
        self.up2 = blku(512 + 512, 64)  # input is concat of x4 skip and y
        self.up3 = blku(64, 32)
        self.up4 = blku(32, 16)
        self.convlast = blk(16, 16)
        # 1x1 conv producing N logits per input channel.
        self.final = nn.Conv2d(16, N * n_channel, 1, bias=False)

        # Self-attention applied at the two lowest-resolution stages.
        self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.tr2 = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8)

        # Class-conditioning embeddings, one per stage they are added to.
        self.cond_embedding_1 = nn.Embedding(10, 16)
        self.cond_embedding_2 = nn.Embedding(10, 32)
        self.cond_embedding_3 = nn.Embedding(10, 64)
        self.cond_embedding_4 = nn.Embedding(10, 512)
        self.cond_embedding_5 = nn.Embedding(10, 512)
        self.cond_embedding_6 = nn.Embedding(10, 64)

        # Per-stage projections of the 32-dim sinusoidal time features.
        self.temb_1 = nn.Linear(32, 16)
        self.temb_2 = nn.Linear(32, 32)
        self.temb_3 = nn.Linear(32, 64)
        self.temb_4 = nn.Linear(32, 512)
        self.N = N

    def forward(self, x, t, cond) -> torch.Tensor:
        # Rescale discrete states to approximately [-1, 1] floats.
        x = (2 * x.float() / self.N) - 1.0
        # 16 sin + 16 cos features of the scaled timestep -> 32 dims.
        t = t.float().reshape(-1, 1) / 1000
        t_features = [torch.sin(t * 3.1415 * 2**i) for i in range(16)] + [
            torch.cos(t * 3.1415 * 2**i) for i in range(16)
        ]
        tx = torch.cat(t_features, dim=1).to(x.device)

        # Broadcastable (B, C, 1, 1) time embeddings for each encoder stage.
        t_emb_1 = self.temb_1(tx).unsqueeze(-1).unsqueeze(-1)
        t_emb_2 = self.temb_2(tx).unsqueeze(-1).unsqueeze(-1)
        t_emb_3 = self.temb_3(tx).unsqueeze(-1).unsqueeze(-1)
        t_emb_4 = self.temb_4(tx).unsqueeze(-1).unsqueeze(-1)

        # Broadcastable (B, C, 1, 1) class embeddings.
        cond_emb_1 = self.cond_embedding_1(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_2 = self.cond_embedding_2(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_3 = self.cond_embedding_3(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_4 = self.cond_embedding_4(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_5 = self.cond_embedding_5(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_6 = self.cond_embedding_6(cond).unsqueeze(-1).unsqueeze(-1)

        # Encoder with 2x average pooling between stages.
        x1 = self.down1(x) + t_emb_1 + cond_emb_1
        x2 = self.down2(nn.functional.avg_pool2d(x1, 2)) + t_emb_2 + cond_emb_2
        x3 = self.down3(nn.functional.avg_pool2d(x2, 2)) + t_emb_3 + cond_emb_3
        x4 = self.down4(nn.functional.avg_pool2d(x3, 2)) + t_emb_4 + cond_emb_4
        x5 = self.down5(nn.functional.avg_pool2d(x4, 2))

        # Attention at the bottleneck: flatten HxW to a sequence, then restore.
        x5 = (
            self.tr1(x5.reshape(x5.shape[0], x5.shape[1], -1).transpose(1, 2))
            .transpose(1, 2)
            .reshape(x5.shape)
        )

        y = self.up1(x5) + cond_emb_5

        y = (
            self.tr2(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2))
            .transpose(1, 2)
            .reshape(y.shape)
        )

        # Skip connection from the deepest encoder stage.
        y = self.up2(torch.cat([x4, y], dim=1)) + cond_emb_6

        y = (
            self.tr3(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2))
            .transpose(1, 2)
            .reshape(y.shape)
        )
        y = self.up3(y)
        y = self.up4(y)
        y = self.convlast(y)
        y = self.final(y)

        # reshape to B, C, H, W, N
        y = (
            y.reshape(y.shape[0], -1, self.N, *x.shape[2:])
            .transpose(2, -1)
            .contiguous()
        )

        return y


class D3PM(nn.Module):
    """Discrete Denoising Diffusion Probabilistic Model (Austin et al., 2021).

    Owns the forward corruption process q(x_t | x_0) over ``num_classes``
    discrete states, the posteriors q(x_{t-1} | x_t, x_0) (true and model
    predicted), the hybrid training loss (weighted VB term + cross-entropy
    on x_0), and ancestral sampling. Timesteps are 1-indexed (1..n_T).

    Args:
        x0_model: network mapping (x_t, t, cond) -> logits over x_0 states,
            shaped like x_t plus a trailing ``num_classes`` dimension.
        n_T: number of diffusion steps.
        num_classes: number of discrete states per element.
        forward_type: transition-matrix family; only "uniform" is implemented.
        hybrid_loss_coeff: weight of the VB term added to the CE loss.
    """

    def __init__(
        self,
        x0_model: nn.Module,
        n_T: int,
        num_classes: int = 10,
        forward_type="uniform",
        hybrid_loss_coeff=0.001,
    ) -> None:
        super(D3PM, self).__init__()
        self.x0_model = x0_model

        self.n_T = n_T
        self.hybrid_loss_coeff = hybrid_loss_coeff

        # Cosine noise schedule; betas are clipped at 0.999 for stability.
        steps = torch.arange(n_T + 1, dtype=torch.float64) / n_T
        alpha_bar = torch.cos((steps + 0.008) / 1.008 * torch.pi / 2)
        self.beta_t = torch.minimum(
            1 - alpha_bar[1:] / alpha_bar[:-1], torch.ones_like(alpha_bar[1:]) * 0.999
        )

        # self.beta_t = [1 / (self.n_T - t + 1) for t in range(1, self.n_T + 1)]
        self.eps = 1e-6
        # NOTE(review): attribute name is misspelled ("classses") but is kept
        # as-is for backward compatibility with existing callers/checkpoints.
        self.num_classses = num_classes
        q_onestep_mats = []
        q_mats = []  # these are cumulative

        for beta in self.beta_t:

            if forward_type == "uniform":
                # Uniform transition: with prob beta jump to a uniformly
                # random state, otherwise stay put.
                mat = torch.ones(num_classes, num_classes) * beta / num_classes
                mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)
                q_onestep_mats.append(mat)
            else:
                raise NotImplementedError
        q_one_step_mats = torch.stack(q_onestep_mats, dim=0)

        q_one_step_transposed = q_one_step_mats.transpose(
            1, 2
        )  # this will be used for q_posterior_logits

        # Cumulative products Q_1 @ ... @ Q_t so q(x_t | x_0) can be sampled
        # in one step.
        q_mat_t = q_onestep_mats[0]
        q_mats = [q_mat_t]
        for idx in range(1, self.n_T):
            q_mat_t = q_mat_t @ q_onestep_mats[idx]
            q_mats.append(q_mat_t)
        q_mats = torch.stack(q_mats, dim=0)
        self.logit_type = "logit"

        # register
        self.register_buffer("q_one_step_transposed", q_one_step_transposed)
        self.register_buffer("q_mats", q_mats)

        assert self.q_mats.shape == (
            self.n_T,
            num_classes,
            num_classes,
        ), self.q_mats.shape

    def _at(self, a, t, x):
        """Gather rows a[t - 1, x], broadcasting t over x's trailing dims."""
        # t is 1-d, x is integer value of 0 to num_classes - 1
        bs = t.shape[0]
        t = t.reshape((bs, *[1] * (x.dim() - 1)))
        # out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m]
        # t is 1-indexed while the stacked matrices are 0-indexed, hence t - 1.
        return a[t - 1, x, :]

    def q_posterior_logits(self, x_0, x_t, t):
        """Logits of q(x_{t-1} | x_t, x_0); where t == 1, returns x_0 logits."""
        # if t == 1, this means we return the L_0 loss, so directly try to x_0 logits.
        # otherwise, we return the L_{t-1} loss.
        # Also, we never have t == 0.

        # if x_0 is integer, we convert it to one-hot.
        if x_0.dtype == torch.int64 or x_0.dtype == torch.int32:
            x_0_logits = torch.log(
                torch.nn.functional.one_hot(x_0, self.num_classses) + self.eps
            )
        else:
            x_0_logits = x_0.clone()

        assert x_0_logits.shape == x_t.shape + (self.num_classses,), print(
            f"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}"
        )

        # Here, we calculate equation (3) of the paper. Note that the x_0 Q_t x_t^T is a normalizing constant, so we don't deal with that.

        # fact1 is "guess of x_{t-1}" from x_t
        # fact2 is "guess of x_{t-1}" from x_0

        fact1 = self._at(self.q_one_step_transposed, t, x_t)

        softmaxed = torch.softmax(x_0_logits, dim=-1)  # bs, ..., num_classes
        # q_mats[t - 2] is the cumulative matrix for step t-1; for t == 1 this
        # wraps to index -1, but that entry is discarded by torch.where below.
        qmats2 = self.q_mats[t - 2].to(dtype=softmaxed.dtype)
        # bs, num_classes, num_classes
        fact2 = torch.einsum("b...c,bcd->b...d", softmaxed, qmats2)

        out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)

        t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim())))

        bc = torch.where(t_broadcast == 1, x_0_logits, out)

        return bc

    def vb(self, dist1, dist2):
        """Mean KL(softmax(dist1) || softmax(dist2)); inputs are logits."""

        # flatten dist1 and dist2
        dist1 = dist1.flatten(start_dim=0, end_dim=-2)
        dist2 = dist2.flatten(start_dim=0, end_dim=-2)

        out = torch.softmax(dist1 + self.eps, dim=-1) * (
            torch.log_softmax(dist1 + self.eps, dim=-1)
            - torch.log_softmax(dist2 + self.eps, dim=-1)
        )
        return out.sum(dim=-1).mean()

    def q_sample(self, x_0, t, noise):
        """Sample x_t ~ q(x_t | x_0) via the Gumbel-max trick on uniform noise."""
        # forward process, x_0 is the clean input.
        logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)
        noise = torch.clip(noise, self.eps, 1.0)
        gumbel_noise = -torch.log(-torch.log(noise))
        return torch.argmax(logits + gumbel_noise, dim=-1)

    def model_predict(self, x_0, t, cond):
        """Run the x0 model; isolated so logit post-processing can be swapped."""
        # this part exists because in general, manipulation of logits from model's logit
        # so they are in form of x_0's logit might be independent to model choice.
        # for example, you can convert 2 * N channel output of model output to logit via get_logits_from_logistic_pars
        # they introduce at appendix A.8.

        predicted_x0_logits = self.x0_model(x_0, t, cond)

        return predicted_x0_logits

    def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch.Tensor:
        """
        Makes forward diffusion x_t from x_0, and tries to guess x_0 value from x_t using x0_model.
        x holds integer values of 0 to num_classes - 1, of dim (bs, ...).
        Returns (hybrid_loss, {"vb_loss": float, "ce_loss": float}).
        """
        # NOTE(review): randint's upper bound is exclusive, so t is drawn from
        # [1, n_T - 1] and t == n_T is never trained on -- confirm intended.
        t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)
        x_t = self.q_sample(
            x, t, torch.rand((*x.shape, self.num_classses), device=x.device)
        )
        # x_t is same shape as x
        assert x_t.shape == x.shape, print(
            f"x_t.shape: {x_t.shape}, x.shape: {x.shape}"
        )
        # we use hybrid loss.

        predicted_x0_logits = self.model_predict(x_t, t, cond)

        # based on this, we first do vb loss.
        true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)

        vb_loss = self.vb(true_q_posterior_logits, pred_q_posterior_logits)

        # Cross-entropy of the model's x_0 prediction against the true x_0.
        predicted_x0_logits = predicted_x0_logits.flatten(start_dim=0, end_dim=-2)
        x = x.flatten(start_dim=0, end_dim=-1)

        ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x)

        return self.hybrid_loss_coeff * vb_loss + ce_loss, {
            "vb_loss": vb_loss.detach().item(),
            "ce_loss": ce_loss.detach().item(),
        }

    def p_sample(self, x, t, cond, noise):
        """One reverse step: sample x_{t-1} from the model's posterior at t."""

        predicted_x0_logits = self.model_predict(x, t, cond)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)

        noise = torch.clip(noise, self.eps, 1.0)

        # At t == 1 take the plain argmax (no Gumbel noise) for the final output.
        not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim())))

        gumbel_noise = -torch.log(-torch.log(noise))
        sample = torch.argmax(
            pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1
        )
        return sample

    def sample(self, x, cond=None):
        """Ancestral sampling from t = n_T - 1 down to 1; x is initial noise."""
        for t in reversed(range(1, self.n_T)):
            t = torch.tensor([t] * x.shape[0], device=x.device)
            x = self.p_sample(
                x, t, cond, torch.rand((*x.shape, self.num_classses), device=x.device)
            )

        return x

    def sample_with_image_sequence(self, x, cond=None, stride=10):
        """Like sample(), but also collects every `stride`-th intermediate x."""
        steps = 0
        images = []
        for t in reversed(range(1, self.n_T)):
            t = torch.tensor([t] * x.shape[0], device=x.device)
            x = self.p_sample(
                x, t, cond, torch.rand((*x.shape, self.num_classses), device=x.device)
            )
            steps += 1
            if steps % stride == 0:
                images.append(x)

        # if last step is not divisible by stride, we add the last image.
        if steps % stride != 0:
            images.append(x)

        return images


if __name__ == "__main__":
    # Train a D3PM with the dummy UNet on MNIST (discretized to N states per
    # pixel) and periodically dump sampling GIFs to ./contents.

    import os

    N = 2  # number of classes for discretized state per pixel
    d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes=N, hybrid_loss_coeff=0.0).cuda()
    print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}")
    dataset = MNIST(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Pad(2),  # pad MNIST to 32x32 (sampling uses 32x32)
            ]
        ),
    )
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32)

    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=1e-3)
    d3pm.train()

    n_epoch = 400
    device = "cuda"

    # Create the sample-output directory up front; without it the first
    # gif.save below raises FileNotFoundError.
    os.makedirs("contents", exist_ok=True)

    global_step = 0
    for i in range(n_epoch):

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, cond in pbar:
            optim.zero_grad()
            x = x.to(device)
            cond = cond.to(device)

            # discretize x to N bins
            x = (x * (N - 1)).round().long().clamp(0, N - 1)
            loss, info = d3pm(x, cond)

            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.1)

            with torch.no_grad():
                param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])

            # Exponential moving average of the loss for a smoother progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            global_step += 1

            # Periodically sample and save the denoising trajectory as a GIF.
            if global_step % 300 == 1:
                d3pm.eval()

                with torch.no_grad():
                    cond = torch.arange(0, 4).cuda() % 10
                    init_noise = torch.randint(0, N, (4, 1, 32, 32)).cuda()

                    images = d3pm.sample_with_image_sequence(
                        init_noise, cond, stride=40
                    )
                    # image sequences to gif
                    gif = []
                    for image in images:
                        x_as_image = make_grid(image.float() / (N - 1), nrow=2)
                        img = x_as_image.permute(1, 2, 0).cpu().numpy()
                        img = (img * 255).astype(np.uint8)
                        gif.append(Image.fromarray(img))

                    gif[0].save(
                        f"contents/sample_{global_step}.gif",
                        save_all=True,
                        append_images=gif[1:],
                        duration=100,
                        loop=0,
                    )

                    last_img = gif[-1]
                    last_img.save(f"contents/sample_{global_step}_last.png")

                d3pm.train()

================================================
FILE: d3pm_runner_cifar10.py
================================================
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
from tqdm import tqdm

import wandb
from d3pm_runner import D3PM
from dit import DiT_Llama

if __name__ == "__main__":
    # Train a D3PM with the DiT backbone on CIFAR10 (discretized to N states
    # per pixel), logging to wandb and dumping sampling GIFs to ./contents.

    import os

    wandb.init(project="d3pm_cifar10")

    N = 8  # number of classes for discretized state per pixel
    d3pm = D3PM(
        DiT_Llama(3, N, dim=1024), 1000, num_classes=N, hybrid_loss_coeff=0.0
    ).cuda()
    print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}")
    dataset = CIFAR10(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        ),
    )
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=16)
    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=2e-5)
    d3pm.train()

    n_epoch = 4000
    device = "cuda"

    # Create the sample-output directory up front; without it the first
    # gif.save below raises FileNotFoundError.
    os.makedirs("contents", exist_ok=True)

    global_step = 0
    for i in range(n_epoch):

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, cond in pbar:
            optim.zero_grad()
            x = x.to(device)
            cond = cond.to(device)

            # discretize x to N bins
            x_cat = (x * (N - 1)).round().long().clamp(0, N - 1)
            loss, info = d3pm(x_cat, cond)

            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 5.0)

            with torch.no_grad():
                param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])

            # Exponential moving average of the loss for a smoother progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()

            if global_step % 10 == 0:
                wandb.log(
                    {
                        "train_loss": loss,
                        "train_grad_norm": norm,
                        "train_param_norm": param_norm,
                    }
                )
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            global_step += 1

            # Periodically sample; the grid shows real (discretized) training
            # images from the last batch next to the generated ones.
            if global_step % 600 == 1:
                d3pm.eval()

                with torch.no_grad():
                    cond = torch.arange(0, 16).cuda() % 10
                    init_noise = torch.randint(0, N, (16, 3, 32, 32)).cuda()

                    images = d3pm.sample_with_image_sequence(
                        init_noise, cond, stride=40
                    )
                    # image sequences to gif
                    gif = []
                    for image in images:
                        x_from_dataloader = x_cat[:16].cpu() / (N - 1)
                        this_image = image.float().cpu() / (N - 1)
                        all_images = torch.cat([x_from_dataloader, this_image], dim=0)
                        x_as_image = make_grid(all_images, nrow=4)
                        img = x_as_image.permute(1, 2, 0).cpu().numpy()
                        img = (img * 255).astype(np.uint8)
                        gif.append(Image.fromarray(img))

                    gif[0].save(
                        f"contents/sample_{global_step}.gif",
                        save_all=True,
                        append_images=gif[1:],
                        duration=100,
                        loop=0,
                    )

                    last_img = gif[-1]
                    last_img.save(f"contents/sample_{global_step}_last.png")

                    # log images
                    wandb.log(
                        {
                            "sample": wandb.Image(last_img),
                        }
                    )

                d3pm.train()


================================================
FILE: dit.py
================================================
# Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    """Apply an adaLN affine transform, broadcasting shift/scale over dim 1."""
    return (1 + scale.unsqueeze(1)) * x + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
    """Maps scalar diffusion timesteps to vectors: sinusoidal features + MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal embedding of t with `dim` channels (cos half, sin half)."""
        half = dim // 2
        exponents = torch.arange(start=0, end=half) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(t.device)
        angles = t[:, None] * freqs[None]
        emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd dim: pad one zero channel so the output has exactly `dim`.
            pad = torch.zeros_like(emb[:, :1])
            emb = torch.cat([emb, pad], dim=-1)
        return emb

    def forward(self, t):
        target_dtype = next(self.parameters()).dtype
        feats = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(feats.to(dtype=target_dtype))


class LabelEmbedder(nn.Module):
    """Embeds class labels, with label dropout for classifier-free guidance.

    When ``dropout_prob > 0`` the embedding table gets one extra row that
    serves as the "null" label dropped-out samples are mapped to.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace a random (or forced) subset of labels with the null label."""
        if force_drop_ids is None:
            # Draw the mask directly on the labels' device. The original code
            # called .cuda() unconditionally, which crashed on CPU-only runs
            # and was redundant before the following .to(labels.device).
            drop_ids = (
                torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            )
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        # Random dropout only during training; forced drops apply any time.
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (LLaMA-style)."""

    def __init__(self, dim, n_heads):
        super().__init__()

        self.n_heads = n_heads
        self.n_rep = 1  # no grouped-query replication: kv heads == q heads
        self.head_dim = dim // n_heads

        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

        # QK-norm, applied on the full projections before the head split.
        self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
        self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim)

    @staticmethod
    def reshape_for_broadcast(freqs_cis, x):
        """Reshape (seqlen, last_dim) freqs so they broadcast against x."""
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        # Keep the seq dim (1) and last dim; collapse every other dim to 1.
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(xq, xk, freqs_cis):
        """Rotate q/k channel pairs in the complex plane by freqs_cis."""
        # View the last dim as complex pairs: (..., head_dim) -> (..., head_dim/2).
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out, xk_out

    def forward(self, x, freqs_cis):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # apply_rotary_emb upcasts to float32; remember the dtype to restore.
        dtype = xq.dtype

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim)

        xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        xq, xk = xq.to(dtype), xk.to(dtype)

        # (B, S, H, D) -> (B, H, S, D) for SDPA, then back, then merge heads.
        output = F.scaled_dot_product_attention(
            xq.permute(0, 2, 1, 3),
            xk.permute(0, 2, 1, 3),
            xv.permute(0, 2, 1, 3),
            dropout_p=0.0,
            is_causal=False,
        ).permute(0, 2, 1, 3)
        output = output.flatten(-2)

        return self.wo(output)


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
        super().__init__()
        # LLaMA-style sizing: shrink to 2/3, optionally rescale, then round
        # up to the next multiple of `multiple_of`.
        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier:
            inner = int(ffn_dim_multiplier * inner)
        inner = -(-inner // multiple_of) * multiple_of  # ceil division

        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        gated = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(gated)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block with optional adaLN-Zero conditioning."""

    def __init__(
        self,
        layer_id,
        dim,
        n_heads,
        multiple_of,
        ffn_dim_multiplier,
        norm_eps,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
        self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)

        # Produces the 6 adaLN parameters (shift/scale/gate for attn and mlp).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(dim, 1024), 6 * dim, bias=True),
        )

    def forward(self, x, freqs_cis, adaln_input=None):
        if adaln_input is None:
            # Unconditioned path: plain pre-norm residual block.
            x = x + self.attention(self.attention_norm(x), freqs_cis)
            return x + self.feed_forward(self.ffn_norm(x))

        params = self.adaLN_modulation(adaln_input).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params

        attn_in = modulate(self.attention_norm(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attention(attn_in, freqs_cis)
        mlp_in = modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
        return x + gate_mlp.unsqueeze(1) * self.feed_forward(mlp_in)


class FinalLayer(nn.Module):
    """Final adaLN-modulated projection to per-patch output channels."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True),
        )
        # Zero-init the projection so this layer starts out emitting zeros.
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)


class DDiT_Llama(nn.Module):
    """Token-level discrete-diffusion transformer (no patchification).

    Predicts per-token logits over N states; the one-hot of the input token
    is added back as a residual so the network only learns a correction.
    """

    def __init__(
        self,
        N=256,
        dim=512,
        n_layers=5,
        n_heads=16,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_gating=False,
    ):
        super().__init__()
        self.N = N
        self.learn_gating = learn_gating
        # With gating, the head emits (logits, gate) pairs -> 2N channels.
        if self.learn_gating:
            self.out_channel = N * 2
        else:
            self.out_channel = N

        self.embedder = nn.Embedding(N, dim)
        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.final_layer = FinalLayer(dim, 1, self.out_channel)
        # Reuses the sibling DiT_Llama's rotary table (max length 4096).
        self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 4096)

    def forward(self, x, t, cond=None):
        # NOTE(review): `cond` is accepted but never used in this variant --
        # confirm whether class conditioning was intended here.
        self.freqs_cis = self.freqs_cis.to(x.device)
        x_onehot = torch.nn.functional.one_hot(x, self.N).to(
            x.device, dtype=next(self.parameters()).dtype
        )
        x = self.embedder(x)
        adaln_input = self.t_embedder(t)

        for layer in self.layers:
            x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)

        x = self.final_layer(x, adaln_input)
        if self.learn_gating:
            # Gated residual: the one-hot is scaled by a learned non-negative gate.
            x, gate = x.chunk(2, dim=-1)
            return x + x_onehot * (1 + gate).abs()
        else:
            return x + x_onehot


class DiT_Llama(nn.Module):
    """Llama-style diffusion transformer over discrete images.

    Input to ``forward`` is an integer tensor (B, in_channels, H, W) with
    values in [0, N). Output is a 5-D tensor of per-pixel, per-channel
    logits over the N classes, residually gated against the one-hot of
    the input (see ``forward``).
    """

    def __init__(
        self,
        in_channels=3,
        N=8,
        input_size=32,
        patch_size=2,
        dim=512,
        n_layers=5,
        n_heads=16,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        class_dropout_prob=0.1,
        num_classes=10,
        learn_sigma=True,
    ):
        super().__init__()
        self.N = N
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # x2 because forward splits the channels into logits and gates.
        self.out_channels = N * in_channels * 2
        self.input_size = input_size
        self.patch_size = patch_size

        # Small conv stem applied before patchification.
        self.init_conv_seq = nn.Sequential(
            nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
            nn.SiLU(),
            nn.GroupNorm(32, dim // 2),
            nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
            nn.SiLU(),
            nn.GroupNorm(32, dim // 2),
        )

        # Linear patch embedding: each patch of the stem output -> dim.
        self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
        nn.init.constant_(self.x_embedder.bias, 0)

        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        self.y_embedder = LabelEmbedder(num_classes, min(dim, 1024), class_dropout_prob)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels)

        # Rotary-embedding table, long enough for up to 4096 patch positions.
        self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 4096)

    def unpatchify(self, x):
        """Inverse of patchify: (B, T, p*p*c) tokens -> (B, c, H, W) image.

        Assumes a square token grid (h == w == sqrt(T)); note the output
        width is written as ``h * p``, which is only valid for square inputs.
        """
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def patchify(self, x):
        """Split (B, C, H, W) into non-overlapping patches -> (B, T, C*p*p)."""
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def forward(self, x, t, y):
        # Keep the RoPE table on the same device as the input.
        self.freqs_cis = self.freqs_cis.to(x.device)

        x_onehot = torch.nn.functional.one_hot(x, self.N).float().to(x.device)
        # Map integer levels [0, N-1] to continuous [-1, 1] for the conv stem.
        x = (2 * x.float() / (self.N - 1)) - 1.0
        x = self.init_conv_seq(x)

        x = self.patchify(x)
        x = self.x_embedder(x)

        t = self.t_embedder(t)  # (N, D)
        y = self.y_embedder(y, self.training)  # (N, D)
        adaln_input = t + y

        for layer in self.layers:
            x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        # Split channels into logits and gates, moving the class axis last.
        # NOTE(review): transpose(2, -1) swaps the class axis with W, giving
        # (B, C, W, H, N); adding x_onehot of shape (B, C, H, W, N) only
        # type-checks because H == W and implicitly transposes the spatial
        # axes — confirm this is intentional.
        x, gate = (
            x.reshape(x.shape[0], -1, self.N * 2, *x.shape[2:])
            .transpose(2, -1)
            .contiguous()
        ).chunk(2, dim=-1)

        return x + x_onehot * (1 + gate).abs()

        # x = (x.reshape(x.shape[0], -1, self.N, *x.shape[2:])
        #     .transpose(2, -1)
        #     .contiguous()
        # )
        # return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """Classifier-free guidance forward pass over a doubled batch.

        NOTE(review): this mirrors the continuous-DiT CFG helper (splitting
        an "eps" slab on channel dim 1), but forward here returns 5-D gated
        logits — verify this method is actually compatible before use.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    @staticmethod
    def precompute_freqs_cis(dim, end, theta=10000.0):
        """Complex rotary-embedding table of shape (end, dim // 2)."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis


def DiT_Llama_600M_patch2(**kwargs):
    # Preset configuration; extra kwargs (N, num_classes, ...) pass through.
    # NOTE: the name says 600M, but dim=256/16 layers is far smaller.
    preset = dict(patch_size=2, dim=256, n_layers=16, n_heads=32)
    return DiT_Llama(**preset, **kwargs)


def DiT_Llama_3B_patch2(**kwargs):
    # ~3B-parameter preset; extra kwargs pass straight to DiT_Llama.
    preset = dict(patch_size=2, dim=3072, n_layers=32, n_heads=32)
    return DiT_Llama(**preset, **kwargs)


if __name__ == "__main__":
    # Smoke test: one forward pass on random discrete inputs.
    model = DiT_Llama_600M_patch2()
    model.eval()
    x = torch.randint(0, 8, (4, 3, 32, 32))  # discrete pixels in [0, 8)
    t = torch.randint(0, 1000, (4,))  # diffusion timesteps
    y = torch.randint(0, 10, (4,))  # class labels

    # fix: eval-only smoke test — no grad tracking needed.
    with torch.no_grad():
        out = model(x, t, y)
    print(out)

    # fix: only repeat on GPU when one is present, so the smoke test
    # no longer crashes on CPU-only machines.
    if torch.cuda.is_available():
        model = model.cuda()
        x = x.cuda()
        t = t.cuda()
        y = y.cuda()
        with torch.no_grad():
            out = model(x, t, y)
        print(out)


================================================
FILE: lm.py
================================================
import math

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import get_scheduler

import wandb
from d3pm_runner import D3PM
from dit import DDiT_Llama


class WikiTextDataset(Dataset):
    """Wikipedia text as fixed-length sequences.

    With a tokenizer, items are padded/truncated token ids; without one,
    items are raw UTF-8 bytes zero-padded to ``max_length``.
    """

    def __init__(self, tokenizer=None, type_path="train", max_length=512, debug=False):
        # vernum mimics the old wikitext-2 / wikitext-103 switch; only the
        # length computation below depends on it.
        self.vernum = 2 if debug else 103
        self.dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # In the non-debug (103) configuration, use only 10% of the corpus.
        n = len(self.dataset)
        if self.vernum == 103:
            return int(n * 0.1)
        return n

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        if self.tokenizer is None:
            # Byte-level path: UTF-8 encode, then zero-pad or truncate.
            raw = list(text.encode("utf-8"))
            if len(raw) >= self.max_length:
                raw = raw[: self.max_length]
            else:
                raw = raw + [0] * (self.max_length - len(raw))
            input_ids = torch.tensor(raw, dtype=torch.long)
        else:
            encoded = self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = encoded.input_ids.squeeze()

        return {"input_ids": input_ids}


if __name__ == "__main__":
    import os

    wandb.init(project="d3pm_wiki")

    N = 256  # vocabulary size: one class per byte value
    max_length = 256  # sequence length in bytes
    num_train_epochs = 5

    d3pm = D3PM(
        DDiT_Llama(N, dim=512, n_layers=6), 1000, num_classes=N, hybrid_loss_coeff=0.0
    ).cuda()

    print(f"Total Param Count: {sum(p.numel() for p in d3pm.x0_model.parameters())}")
    dataset = WikiTextDataset(max_length=max_length, debug=False)
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=8)
    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=2e-4)

    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optim,
        num_warmup_steps=100,
        # fix: len(dataloader) is already an int batch count; math.ceil was a no-op
        num_training_steps=num_train_epochs * len(dataloader),
    )

    d3pm.train()

    device = "cuda"
    # fix: the checkpoint directory must exist before torch.save below
    os.makedirs("ckpt", exist_ok=True)

    global_step = 0
    for epoch in range(num_train_epochs):

        pbar = tqdm(dataloader)
        loss_ema = None
        for batch in pbar:
            optim.zero_grad()
            x = batch["input_ids"].to(device)

            loss, info = d3pm(x)

            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 5.0)

            with torch.no_grad():
                param_norm = sum(torch.norm(p) for p in d3pm.x0_model.parameters())

            # Exponential moving average of the loss, for a smoother progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()

            if global_step % 10 == 0:
                wandb.log(
                    {
                        "train_loss": loss.item(),
                        "train_grad_norm": norm,
                        "train_param_norm": param_norm,
                    }
                )

            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            lr_scheduler.step()
            global_step += 1

            if global_step % 600 == 1:
                d3pm.eval()

                with torch.no_grad():
                    init_noise = torch.randint(0, N, (16, max_length)).cuda()

                    outputs = d3pm.sample_with_image_sequence(
                        init_noise, None, stride=40
                    )
                    gen_outputs = []
                    total = 0
                    # Decode each generated byte sequence back to text.
                    for _i in range(16):
                        sent = outputs[-1][_i].cpu().tolist()
                        correctly_parsed = True
                        try:
                            sent = bytes(sent).decode("utf-8")
                        except UnicodeDecodeError:
                            # fix: catch only the decode failure (a bare except
                            # also swallowed KeyboardInterrupt); fall back to
                            # rendering each byte as its own code point.
                            correctly_parsed = False
                            sent = "".join(chr(i) for i in sent)
                        sent = (
                            f"[{_i}] Sample Correctly parsed: "
                            + str(correctly_parsed)
                            + "\n"
                            + sent
                        )
                        total += 1 if correctly_parsed else 0

                        gen_outputs.append(sent)

                    print(sent)
                    # Make a nice html page of the generated outputs.
                    html_formatted = "<br>".join(gen_outputs)
                    wandb.log(
                        {
                            "generated_text": wandb.Html(html_formatted),
                            "correctly_parsed": total,
                        }
                    )

                d3pm.train()

            # fix: hoisted out of the `% 600` block. The trigger set is
            # unchanged (3000 is a multiple of 600) but saving no longer
            # silently depends on the sampling interval.
            if global_step % 3000 == 1:
                torch.save(d3pm.state_dict(), f"ckpt/d3pm_wiki_{global_step}.pth")
                print(f"Model saved at {global_step}")


================================================
FILE: lm_deepspeed.py
================================================
import math
import os
import random

import click
import deepspeed
import numpy as np
import torch
from datasets import load_dataset
from deepspeed import get_accelerator
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.utils import logger
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import default_data_collator, get_scheduler

import wandb
from d3pm_runner import D3PM
from dit import DDiT_Llama


class WikiTextDataset(Dataset):
    """Text corpus as fixed-length sequences for the DeepSpeed trainer.

    With a tokenizer, items are padded/truncated token ids; without one,
    items are raw UTF-8 bytes zero-padded to ``max_seq_length``.
    ``debug`` swaps in the tiny wikitext-2 test split for fast iteration.
    """

    def __init__(
        self, tokenizer=None, type_path="train", max_seq_length=512, debug=False
    ):
        if debug:
            # fix: dropped the pointless f-string prefix (no placeholders)
            self.dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        else:
            self.dataset = load_dataset(
                "wikimedia/wikipedia", "20231101.en", split="train"
            )

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        if self.tokenizer is not None:
            inputs = self.tokenizer(
                text,
                max_length=self.max_seq_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = inputs.input_ids.squeeze()
        else:
            # Byte-level path: UTF-8 encode, truncate, then zero-pad.
            seq = list(text.encode("utf-8"))[: self.max_seq_length]
            seq += [0] * (self.max_seq_length - len(seq))
            input_ids = torch.tensor(seq, dtype=torch.long)

        return {"input_ids": input_ids}


def _z3_params_to_fetch(param_list):
    return [
        p
        for p in param_list
        if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]


def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save a (possibly ZeRO-3 sharded) model to ``save_dir/pytorch_model.bin``.

    For ZeRO stages 0-2 every rank holds the full parameters, so rank 0
    saves directly. For stage 3 each partitioned parameter is gathered
    before being copied to CPU; only rank 0 writes the file.
    """
    zero_stage_3 = zero_stage == 3
    os.makedirs(save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)

    # Unwrap DDP / DeepSpeed engine wrappers, if any.
    model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema
    if not zero_stage_3:
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        output_state_dict = {}
        for k, v in model_to_save.named_parameters():
            if hasattr(v, "ds_id"):
                # Gather this partitioned parameter from all ranks before copying.
                with deepspeed.zero.GatheredParameters(
                    _z3_params_to_fetch([v]), enabled=zero_stage_3
                ):
                    v_p = v.data.cpu()
            else:
                v_p = v.cpu()
            # NOTE(review): "lora" keys are excluded from the checkpoint —
            # no LoRA modules are visible in this repo; confirm intended.
            if global_rank == 0 and "lora" not in k:
                output_state_dict[k] = v_p
        if global_rank == 0:
            torch.save(output_state_dict, output_model_file)
        del output_state_dict


def set_seed(seed=42):
    """Seed the python, numpy, and torch (CPU + all CUDA) RNGs."""
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)


@click.command()
@click.option("--local_rank", default=-1, help="Local rank")
@click.option("--max_seq_length", default=256, help="Max sequence length")
@click.option("--num_train_epochs", default=5, help="Number of training epochs")
@click.option("--learning_rate", default=1e-4, help="Learning rate")
@click.option("--offload", default=False, help="Offload")
@click.option("--train_batch_size", default=1024, help="Train batch size")
@click.option(
    "--per_device_train_batch_size", default=64, help="Per device train batch size"
)
@click.option("--zero_stage", default=2, help="Zero stage")
@click.option("--seed", default=42, help="Seed")
@click.option("--run_name", default=None, help="Run name")
def main(
    local_rank,
    max_seq_length=256,
    num_train_epochs=5,
    learning_rate=1e-4,
    offload=False,
    train_batch_size=512,
    per_device_train_batch_size=64,
    zero_stage=2,
    seed=42,
    run_name=None,
):
    """Train the byte-level D3PM language model on Wikipedia with DeepSpeed.

    Launched via run_multigpu.sh (one process per GPU). Periodically samples
    from the model, logs decoded text to wandb, and saves checkpoints.

    NOTE(review): the click default for --train_batch_size is 1024 while the
    python default here is 512; click's value wins when run as a CLI.
    """
    # first, set the seed
    set_seed(seed)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)

    if run_name is None:
        run_name = f"LR:{learning_rate}_max_seq_length:{max_seq_length}_num_train_epochs:{num_train_epochs}_offload:{offload}"

    if local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(local_rank)
        device = torch.device(get_accelerator().device_name(), local_rank)
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        deepspeed.init_distributed()

    offload_device = "cpu" if offload else "none"

    # DeepSpeed engine configuration; ZeRO stage and offload come from the CLI.
    ds_config = {
        "train_micro_batch_size_per_gpu": per_device_train_batch_size,
        "train_batch_size": train_batch_size,
        "zero_optimization": {
            "stage": zero_stage,
            "offload_param": {"device": offload_device},
            "offload_optimizer": {"device": offload_device},
            "stage3_param_persistence_threshold": 1e4,
            "stage3_max_live_parameters": 3e7,
            "stage3_prefetch_bucket_size": 3e7,
            "memory_efficient_linear": False,
        },
        "bfloat16": {"enabled": True},
        "gradient_clipping": 1.0,
    }

    # NOTE(review): this assumes torch.distributed is initialized — with
    # local_rank == -1 no init_distributed() was called above; verify.
    torch.distributed.barrier()

    global_rank = torch.distributed.get_rank()

    ##### DEFINE model, dataset, sampler, dataloader, optim, schedular
    N = 256
    # Under ZeRO-3, construct the model with partitioned parameters.
    with deepspeed.zero.Init(enabled=(zero_stage == 3)):

        d3pm = D3PM(
            DDiT_Llama(N, dim=768, n_layers=8),
            1000,
            num_classes=N,
            hybrid_loss_coeff=0.0,
        ).cuda()

    total_params = sum(p.numel() for p in d3pm.parameters())
    size_in_bytes = total_params * 4
    size_in_gb = size_in_bytes / (1024**3)
    logger.info(
        f"Model Size: {size_in_bytes}, {size_in_gb} GB, Total Param Count: {total_params / 1e6} M"
    )

    dataset = WikiTextDataset(max_seq_length=max_seq_length, debug=False)

    train_sampler = (
        RandomSampler(dataset)
        if local_rank == -1
        else DistributedSampler(dataset, seed=seed)
    )

    dataloader = DataLoader(
        dataset,
        collate_fn=default_data_collator,
        sampler=train_sampler,
        batch_size=per_device_train_batch_size,
    )

    optimizer = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=learning_rate)

    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=num_train_epochs * math.ceil(len(dataloader)),
    )

    d3pm.train()

    # Wrap model/optimizer/scheduler in the DeepSpeed engine.
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=d3pm, config=ds_config, lr_scheduler=lr_scheduler, optimizer=optimizer
    )

    global_step = 0

    ##### actual training loop

    # Only rank 0 talks to wandb.
    if global_rank == 0:
        wandb.init(
            project="d3pm_wiki",
            name=run_name,
            config={
                "N": N,
                "max_seq_length": max_seq_length,
                "num_train_epochs": num_train_epochs,
                "learning_rate": learning_rate,
                "offload": offload,
                "train_batch_size": train_batch_size,
                "per_device_train_batch_size": per_device_train_batch_size,
                "zero_stage": zero_stage,
                "seed": seed,
            },
        )

    for i in range(num_train_epochs):
        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:

            x = x["input_ids"].to(model_engine.device)

            # discritize x to N bins

            loss, info = model_engine(x)
            model_engine.backward(loss)
            model_engine.step()

            get_accelerator().empty_cache()
            norm = model_engine.get_global_grad_norm()

            if global_step % 10 == 0:
                if global_rank == 0:
                    wandb.log({"train_loss": loss, "train_grad_norm": norm})

            pbar.set_description(
                f"norm: {norm}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )

            global_step += 1

            # Periodic sampling: decode generated byte sequences and log them.
            if global_step % 600 == 1:
                d3pm.eval()

                with torch.no_grad():
                    init_noise = torch.randint(0, N, (16, max_seq_length)).cuda()

                    outputs = d3pm.sample_with_image_sequence(
                        init_noise, None, stride=40
                    )
                    gen_outputs = []
                    total = 0
                    # back to sentence, byte to utf-8
                    for _i in range(16):
                        sent = outputs[-1][_i].cpu().tolist()
                        correctly_parsed = True
                        try:
                            sent = b"".join([bytes([i]) for i in sent]).decode("utf-8")
                        except:
                            # NOTE(review): bare except also swallows
                            # KeyboardInterrupt; UnicodeDecodeError is the
                            # intended failure — fall back to per-byte chars.
                            correctly_parsed = False
                            sent = "".join([chr(i) for i in sent])
                        sent = (
                            f"[{_i}] Sample Correctly parsed: "
                            + str(correctly_parsed)
                            + "\n"
                            + sent
                        )
                        total += 1 if correctly_parsed else 0

                        gen_outputs.append(sent)

                    print(sent)
                    model_engine.train()

                    # make a nice html to show the generated outputs
                    html_formatted = "<br>".join(gen_outputs)
                    # log text
                    if global_rank == 0:
                        wandb.log(
                            {
                                "generated_text": wandb.Html(html_formatted),
                                "correctly_parsed": total,
                            }
                        )
                    # NOTE(review): this save is nested inside the `% 600`
                    # block; it only works because 3000 is a multiple of 600.
                    if global_step % 3000 == 1:
                        save_zero_three_model(
                            model_engine, global_rank, "./ckpt", zero_stage=zero_stage
                        )

                        print(f"Model saved at {global_step}")


if __name__ == "__main__":
    # CLI entry point; click parses the options declared on main().
    main()


================================================
FILE: readme.md
================================================
<p align="center">
  <img src="contents/output.gif" alt="large" width="400">
  <img src="contents/cifar_best.gif" alt="large" width="200">
</p>


# Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch


<p align="center">
  <img src="contents/best.gif" alt="small" width="400">
  <img src="contents/best.png" alt="small" width="400">
</p>


**Special thanks to [fal.ai](https://fal.ai/) for the compute resources for this project.**


This is a minimal (400 LOC) but fully faithful PyTorch implementation of D3PM: [Structured Denoising Diffusion Models in Discrete State-Spaces](https://arxiv.org/abs/2107.03006).

I have tried to keep the code as simple as possible with much comments and explanation that is somewhat lacking on the original jax implementation, so that it is easy to understand. As far as I know, this is the first, faithful reimplementation of D3PM in pytorch. (Please correct me if I am wrong). Of course, this implementation was heavily based on the [official implementation](https://github.com/google-research/google-research/tree/master/d3pm/images).

Difference between this implementation and the official implementation:

* This one has conditional sampling, so as you can see, generations are class-conditioned.
* This one uses rather different/simple model architecture.
* This one greatly simplifies the official implementation, down to about 400 LOC.
* This one does not use truncated logistic reparameterization, but you can use that if you wish.
* Only has uniform sampling with an inverse-linear beta schedule, but you can change that with a couple of lines as well.

## Usage


Following is completely self-contained example.

```bash
python d3pm_runner.py
```

Following uses dit.py, for CIFAR-10 dataset.
  
```bash
python d3pm_runner_cifar10.py
```

## Requirements

Install torch, torchvision, pillow, tqdm

```bash
pip install torch torchvision pillow tqdm
```

## Citation

This implementation:

```bibtex
@misc{d3pm_pytorch,
  author={Simo Ryu},
  title={Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch},
  year={2024},
  howpublished={\url{https://github.com/cloneofsimo/d3pm}}
}
```

Original Paper:

```bibtex
@article{austin2021structured,
  title={Structured denoising diffusion models in discrete state-spaces},
  author={Austin, Jacob and Johnson, Daniel D and Ho, Jonathan and Tarlow, Daniel and Van Den Berg, Rianne},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={17981--17993},
  year={2021}
}
```


================================================
FILE: run_multigpu.sh
================================================
# Launch one DeepSpeed process per visible GPU (counted via nvidia-smi).
export WORLD_SIZE=$(nvidia-smi -L | wc -l)
deepspeed --num_gpus $WORLD_SIZE lm_deepspeed.py --learning_rate 1e-4

================================================
FILE: test.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def get_logits_from_logistic_pars(loc, log_scale, num_classes = 10):\n",
    "    loc = loc.unsqueeze(-1)\n",
    "    log_scale = log_scale.unsqueeze(-1)\n",
    "\n",
    "    inv_scale = (-log_scale + 2.0).exp()\n",
    "    \n",
    "    bin_width = 2.0 / (num_classes - 1)\n",
    "    bin_centers = torch.linspace(-1.0, 1.0, num_classes).to(loc.device)\n",
    "    bin_centers = bin_centers.reshape((*loc.shape, num_classes))\n",
    "    bin_centers = bin_centers - loc\n",
    "    log_cdf_min = -torch.log1p((-inv_scale * (bin_centers - 0.5 * bin_width)).exp())\n",
    "    log_cdf_plus = -torch.log1p((-inv_scale * (bin_centers + 0.5 * bin_width)).exp())\n",
    "    logits = log_minus_exp(log_cdf_plus, log_cdf_min)\n",
    "    return logits\n",
    "\n",
    "def log_minus_exp(a, b, epsilon=1.e-6):\n",
    "    return a + torch.log1p(-torch.exp(b - a) + epsilon)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "loc = torch.randn(1, 1, 1, 1).clip(-1, 1)\n",
    "log_scale = torch.randn(1, 1, 1, 1)\n",
    "\n",
    "num_classes = 10\n",
    "logits = get_logits_from_logistic_pars(loc, log_scale, num_classes)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.4609, -1.4102, -1.8315, -2.5864, -3.5119, -4.5085, -5.5319, -6.5647,\n",
       "         -7.6002, -8.6343]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits[0,0,0,0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGiCAYAAADJO+2bAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABQWUlEQVR4nO3deVxU5f4H8M+wzLAIw76og4IbbriLoKWVaaalVlZmKmau3MzsVnhvZrahV+t2f+aWuWUZmblUZqallor7rqCiKIuACDIDCAPMPL8/zClSEXTgOQOf9+t1Xq/m8JzhezjCfDrnWVRCCAEiIiIiBbKTXQARERHR7TCoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYkkPKvn5+Zg8eTIaNWoEZ2dnREZGYv/+/bLLIiIiIgWQHlRefPFFbNmyBStXrsTx48fRp08f9O7dG+np6bJLIyIiIslUMhclLCoqgpubGzZs2ID+/ftb9nfq1An9+vXDe++9J6s0IiIiUgAHmd+8rKwMJpMJTk5O5fY7Oztj586dtzzGaDTCaDRaXpvNZuTm5sLb2xsqlapa6yUiIiLrEEIgPz8f9evXh51dBQ94hGQRERGiZ8+eIj09XZSVlYmVK1cKOzs70bx581u2nz59ugDAjRs3bty4casFW2pqaoU5QeqjHwA4d+4cXnjhBfz222+wt7dHx44d0bx5cxw8eBAJCQk3tf/7HRW9Xo+goCCkpqbC3d29JksnIiKiu2QwGKDT6ZCXlwetVnvbdlIf/QBAkyZNsGPHDhQWFsJgMCAwMBDPPPMMQkJCbtleo9FAo9HctN/d3Z1BhYiIyMbcqduG9FE/N7i6uiIwMBBXr17F5s2bMXDgQNklERERkWTS76hs3rwZQgi0aNECSUlJeO211xAaGopRo0bJLo2IiIgkk35HRa/XIzo6GqGhoRgxYgR69OiBzZs3w9HRUXZpREREJJn0zrT3ymAwQKvVQq/Xs48KERGRjajs57f0OypEREREt8OgQkRERIrFoEJERESKxaBCREREisWgQkRERIrFoEJERESKxaBCREREisWgQkRERIolNaiYTCZMmzYNwcHBcHZ2RpMmTfDuu+/CxuegIyIiIiuRutbPrFmzsGDBAqxYsQKtW7fGgQMHMGrUKGi1WkyaNElmaURERKQAUoPK7t27MXDgQPTv3x8A0LhxY3z11VfYt2+fzLKIiIhIIaQ++omMjMQvv/yCM2fOAACOHj2KnTt3ol+/frc9xmg0wmAwlNuIiIiodpJ6RyUmJgYGgwGhoaGwt7eHyWTC+++/j2HDht32mNjYWMyYMaMGqyQiIiJZpN5RWb16Nb788kusWrUKhw4dwooVKzBnzhysWLHitsdMnToVer3esqWmptZgxURERFSTVELiEBudToeYmBhER0db9r333nv44osvkJiYWKn3qOwy0URERKQclf38lnpH5dq1a7CzK1+Cvb09zGazpIqIiIhISaT2UXnsscfw/vvvIygoCK1bt8bhw4fx0Ucf4YUXXpBZFhERESmE1Ec/+fn5mDZtGtatW4fLly+jfv36GDp0KN566y2o1epKvQcf/RAREdmeyn5+Sw0q1sCgQkREZHtsoo8KERERUUUYVIiIiEixGFSIiIhIsRhUiIiISLEYVIiIiEixGFSIiIhIsRhUiIiISLEYVIiIiEixGFSIiIhIsaQHlcaNG0OlUt20/XVFZSIiIqqbpC5KCAD79++HyWSyvD5x4gQefvhhDBkyRGJVREREpATSg4qvr2+51zNnzkSTJk3Qs2dPSRURERGRUkgPKn9VUlKCL774AlOmTIFKpbplG6PRCKPRaHltMBhqqjwiIiKqYdL7qPzV+vXrkZeXh6ioqNu2iY2NhVartWw6na7mCiQiIqIapRJCCNlF3NC3b1+o1Wp8
//33t21zqzsqOp3ujstEExERkXIYDAZotdo7fn4r5tHPxYsXsXXrVqxdu7bCdhqNBhqNpoaqIiIiIpkU8+hn2bJl8PPzQ//+/WWXQkRERAqhiKBiNpuxbNkyjBw5Eg4OirnJQ0RERJIpIqhs3boVKSkpeOGFF2SXQkRERAqiiNsXffr0gYL69BIREZFCKOKOChEREdGtMKgQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYkkPKunp6Xj++efh7e0NZ2dntG3bFgcOHJBdFhERESmA1Jlpr169iu7du+OBBx7Apk2b4Ovri7Nnz8LT01NmWURERKQQUoPKrFmzoNPpsGzZMsu+4OBgiRURERGRkkh99PPdd9+hc+fOGDJkCPz8/NChQwcsXry4wmOMRiMMBkO5jYiIiGonqUHl/PnzWLBgAZo1a4bNmzdjwoQJmDRpElasWHHbY2JjY6HVai2bTqerwYqJiIioJqmExGWL1Wo1OnfujN27d1v2TZo0Cfv370d8fPwtjzEajTAajZbXBoMBOp0Oer0e7u7u1V4zERER3TuDwQCtVnvHz2+pd1QCAwPRqlWrcvtatmyJlJSU2x6j0Wjg7u5ebiMiIqLaSWpQ6d69O06fPl1u35kzZ9CoUSNJFREREZGSSA0qr7zyCvbs2YMPPvgASUlJWLVqFT799FNER0fLLIuIiIgUQmpQ6dKlC9atW4evvvoKbdq0wbvvvouPP/4Yw4YNk1kWERERKYTUzrTWUNnOOERERKQcNtGZloiIiKgiDCpERESkWAwqREREpFgMKkRERKRYDCpERESkWAwqREREpFgMKkRERKRYDCpERESkWAwqREREpFjSg8rbb78NlUpVbgsNDZVdFhERESmAg+wCAKB169bYunWr5bWDgyLKIiIiIskUkQgcHBwQEBBQqbZGoxFGo9Hy2mAwVFdZREREJJn0Rz8AcPbsWdSvXx8hISEYNmwYUlJSbts2NjYWWq3Wsul0uhqslIiIiGqS9NWTN23ahIKCArRo0QIZGRmYMWMG0tPTceLECbi5ud3U/lZ3VHQ6HVdPJiIisiGVXT1ZelD5u7y8PDRq1AgfffQRRo8efcf2lT1RIiIiUo7Kfn4r4tHPX3l4eKB58+ZISkqSXQoRERFJprigUlBQgHPnziEwMFB2KURERCSZ9KDyz3/+Ezt27MCFCxewe/duDB48GPb29hg6dKjs0oiIiEgy6cOT09LSMHToUOTk5MDX1xc9evTAnj174OvrK7s0IiIikkx6UImLi5NdAhERESmU9Ec/RERERLfDoEJERESKxaBCREREisWgQkRERIrFoEJERESKxaBCREREisWgQkRERIrFoEJERESKpaigMnPmTKhUKkyePFl2KURERKQAigkq+/fvx6JFixAWFia7FCIiIlIIRQSVgoICDBs2DIsXL4anp6fscoiIiEghFBFUoqOj0b9/f/Tu3fuObY1GIwwGQ7mNiIiIaidFLEp46NAh7N+/v1LtY2NjMWPGjGquioiIiJRA6h2V1NRUvPzyy/jyyy/h5ORUqWOmTp0KvV5v2VJTU6u5SiIiIpJFJYQQsr75+vXrMXjwYNjb21v2mUwmqFQq2NnZwWg0lvvarRgMBmi1Wuj1eri7u1d3yURERGQFlf38lvro56GHHsLx48fL7Rs1ahRCQ0Pxxhtv3DGkEBERUe0mNai4ubmhTZs25fa5urrC29v7pv1ERERU9yhi1A8RERHRrUgf9fN327dvl10CERERKQTvqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWJJDyoLFixAWFgY3N3d4e7ujoiICGzatEl2WURERKQA0oNKw4YNMXPmTBw8eBAHDhzAgw8+iIEDB+LkyZOySyMiIiLJVEIIIbuIv/Py8sLs2bMxevTom75m
NBphNBotrw0GA3Q6HfR6Pdzd3WuyTCIiIrpLBoMBWq32jp/f0u+o/JXJZEJcXBwKCwsRERFxyzaxsbHQarWWTafT1XCVREREVFMUcUfl+PHjiIiIQHFxMerVq4dVq1bh0UcfvWVb3lEhIiKyfZW9o+JQgzXdVosWLXDkyBHo9XqsWbMGI0eOxI4dO9CqVaub2mo0Gmg0GglVEhERUU1TxB2Vv+vduzeaNGmCRYsW3bFtZRMZERERKYdN9lG5wWw2l3u8Q0RERHWT9Ec/U6dORb9+/RAUFIT8/HysWrUK27dvx+bNm2WXRkRERJJJDyqXL1/GiBEjkJGRAa1Wi7CwMGzevBkPP/yw7NKIiIhIMulBZcmSJbJLICIiIoVSZB8VIiIiIoBBhYiIiBSMQYWIiIgUi0GFiIiIFItBhYiIiBSLQYWIiIgUi0GFiIiIFItBhYiIiBRLelCJjY1Fly5d4ObmBj8/PwwaNAinT5+WXRYREREpgPSgsmPHDkRHR2PPnj3YsmULSktL0adPHxQWFsoujYiIiCRTCSGE7CL+Kjs7G35+ftixYwfuv//+O7av7DLRREREpByV/fyWvtbP3+n1egCAl5fXLb9uNBphNBotrw0GQ43URURERDVP+qOfvzKbzZg8eTK6d++ONm3a3LJNbGwstFqtZdPpdDVcJREREdUURT36mTBhAjZt2oSdO3eiYcOGt2xzqzsqOp2Oj36IiIhsiM09+vnHP/6BH374Ab/99tttQwoAaDQaaDSaGqyMiIiIZJEeVIQQeOmll7Bu3Tps374dwcHBsksiIiIihZAeVKKjo7Fq1Sps2LABbm5uyMzMBABotVo4OztLro6IiIhkkt5HRaVS3XL/smXLEBUVdcfjOTyZiIjI9thMHxUF9eUlIiIihVHU8GQiIiKiv2JQISIiIsViUCEiIiLFYlAhIiIixWJQISIiIsViUCEiIiLFYlAhIiIixWJQISIiIsViUCEiIiLFkh5UfvvtNzz22GOoX78+VCoV1q9fL7skIiIiUgjpQaWwsBDt2rXDvHnzZJdCRERECiN9rZ9+/fqhX79+lW5vNBphNBotrw0GQ3WURURERAog/Y5KVcXGxkKr1Vo2nU4nuyQiIiKqJjYXVKZOnQq9Xm/ZUlNTZZdERERE1UT6o5+q0mg00Gg0sssgIiKiGmBzd1SIiIio7mBQISIiIsWS/uinoKAASUlJltfJyck4cuQIvLy8EBQUJLEyIiIikk16UDlw4AAeeOABy+spU6YAAEaOHInly5dLqoqIiIiUQHpQ6dWrF4QQsssgIiIiBWIfFSIiIlIsBhUiIiJSLAYVIiIiUiwGFSIiIlIsBhUiIiJSLAYVIiIiUiwGFSIiIlIsBhUiIiJSLEUElXnz5qFx48ZwcnJCeHg49u3bJ7skIiIiUgDpQeXrr7/GlClTMH36dBw6dAjt2rVD3759cfnyZdmlERERkWTSg8pHH32EMWPGYNSoUWjVqhUWLlwIFxcXLF26VHZpREREJJnUoFJSUoKDBw+id+/eln12dnbo3bs34uPjb3mM0WiEwWAotxEREVHtJDWoXLlyBSaTCf7+/uX2+/v7IzMz85bHxMbGQqvVWjadTlcTpRIREZEE0h/9VNXUqVOh1+stW2pqquySiIiIqJo4yPzmPj4+sLe3R1ZWVrn9WVlZCAgIuOUxGo0GGo2mJsojIiIiyaTeUVGr1ejUqRN++eUXyz6z2YxffvkFEREREisjIiIiJZB6RwUApkyZgpEjR6Jz587o2rUrPv74YxQWFmLUqFGySyMiIiLJpAeVZ555BtnZ2XjrrbeQmZmJ9u3b46effrqpgy0RERHVPSohhJBdxL0wGAzQarXQ6/Vwd3eXXQ4RERFVQmU/v21u1A8RERHVHQwqREREpFgMKkRERKRYDCpERESkWAwqREREpFgMKkRERKRY0udRISICgJIyM/RFpZbN/LeZE+ztVNA6O1o2R3v+fxZRXcCg
QkQ1QgiBLIMRJy/pkXylECm5165vOdeQaSjGtRJTld7PVW2PAK0Tgrxc0MjbFUFeLgj2cUXrBu7wc3OqprMgoprGoEJE1aK41IT9F3Jx8OJVHE/T41i6Htn5xgqPUakAdydHuDs7wNGu/B2TEpMZhqJSGIrLAACFJSacyy7EuexCANnl2ga4O6FtQy3CGmjRqZEnOjX2hMbB3qrnR0Q1Q2pQef/997Fx40YcOXIEarUaeXl5MsshontgNgucvGTA70nZ2JV0BfsvXEVJmblcG3s7FZr51UMT33oI8na5fjfEywX1PZzh4eIINydH2NupKvw+JrOAoagUeUWlSL9ahIu5hZY7M0mXC5CUXYBMQzEyTxVjy6nrK7M7Odqha7A3ejT1xn3NfBEa4AaVquLvQ0TKIHUK/enTp8PDwwNpaWlYsmTJXQUVTqFPJI/ZLHA49So2HsvEphMZyNAXl/t6oNYJ3UK80a6hFm0beqBVoDuc1dV7Z6PQWIZTGQYcS9PjaGoe9pzPweW/3cnReTnj0TaBeLRtIMIaahlaiCSo7Oe3Itb6Wb58OSZPnlypoGI0GmE0/vlHx2AwQKfTMagQ1aAzWfn4en8qNh7LQKbhz3DiqrZHRBMf3NfMBz2a+SDEx1V6CBBC4OzlAvx+9gp2ns1G/PkcFJf+eaenoaczBoTVx9OdGyLEt57ESonqlsoGFZvroxIbG4sZM2bILoOozik0lmHjsQzE7U/BoZQ8y/56Ggf0bumHR9sG4v7mvnByVFZfEJVKheb+bmju74bRPYJxraQM209nY+PxDPyacBlpV4uwcMc5LNxxDuHBXni2qw792gQq7jyI6ireUSGiCl3MKcTSncn49lA6CozXO7La26nwUKgfnurUUJHhpLKKSkzYdvoy1hxMw/bTl2H+46+hu5MDnumiQ1T3YDTwcJZbJFEtJe2OSkxMDGbNmlVhm4SEBISGht7V+2s0Gmg0mrs6logq7+DFXCz+LRmbT2Xixv/ONPZ2wTNdgvBkpwa1Ygiws9oej7a93lflUl4R1hxMw9f7U5GeV4TFvydj6a4L6N82EC/eF4ywhh6yyyWqk6x+RyU7Oxs5OTkVtgkJCYFarba8rsodlb9jZ1oi6xFCYMeZbPzfL2fLPd7p1cIXo3sEo0dTH+l9Tqqb2Syw/cxlfPZ7Mnaf+/NvWbcQL0x6qBkiQrxr/c+AqCZIu6Pi6+sLX19fa78tEVWjGwHl461ncSQ1DwCgtrfD4A4NMPq+YDT3d5NbYA2ys1PhwVB/PBjqjxPpeizZmYzvj17CnvO52HN+L7oGe2Fy72aIbOIju1SiOkFqZ9qUlBTk5uYiJSUFJpMJR44cAQA0bdoU9eqx9z1RTdh59grm/HzaElCcHO0wvFsjjLk/pFY83rkXbRpo8d9n2uP1R1pg4fZz+GpfKvYl5+K5xXsRHuyFf/ZtgS6NvWSXSVSrSe1MGxUVhRUrVty0f9u2bejVq1el3oOPfojuTkKGAbGbEvHbmeuzut4IKGPvbwJfN/YDu5UMfREWbD+HuH2pKDFdH+Lcp5U/3ugXiiYc2kxUJTY1j8q9YFAhqppMfTE+/Pk01hxKgxCAo70Kw8IbIfqBpgwolZShL8LcX5Pw9f5UmMwC9nYqDO2qw8sPNefPkKiSGFSIqJziUhMW/3Ye87YnWSY86982EK8/0gKNvF0lV2ebki7nY+amRGxNuAzg+oR3L/duhqjIYKgduLozUUUYVIjIYuupLLzzwymk5F4DAHRp7Il/PdoSHYI8JVdWO+w5n4MPfkzAsTQ9AKCJrytmPN4GPZqxwy3R7TCoEBEuXCnEOz+cwq+J1/+PP8DdCf/u3xIDwgI5xNbKzGaBNYfSMGtTInIKSwAA/doE4M0BrThpHNEtMKgQ1WGlJjM+/e08/vfLWZSUmeFor8LoHiF46cGmcNXY3MoZNkVfVIr/bjmDz+MvwCwAF7U9Xu3TAlGRje+4MjRRXcKgQlRHHUq5
iqnfHsfprHwAQI+mPpgxsDVHpdSwhAwD3tpwAvsvXAUAhDXU4oPBbdGmgVZyZUTKwKBCVMcUGMvwn58SsXLPRQgBeLmq8daAVhjYvj4f80hiNgvE7U9F7KYE5BeXwd5OhRd7BGNy7+ZwVtvm+khE1sKgQlSH7Eq6gtfXHEN6XhEA4MmODfHv/i3h5aq+w5FUEy4bijHj+1PYeDwDABDs44o5Q8LQqREni6O6i0GFqA4oMJYh9scEfLk3BQDQ0NMZM58I42gThdp6Kgv/Xn8cWQYjVCpgdPdg/LNvC5tdfZroXlT281vaQP8LFy5g9OjRCA4OhrOzM5o0aYLp06ejpKREVklENmX3uSt45OPfLCFleLdG2Dz5foYUBevdyh8/T+6JJzs2hBDAZzuT8ej/fsfBi1dll0akWNK6/ycmJsJsNmPRokVo2rQpTpw4gTFjxqCwsBBz5syRVRaR4hWXmjB782ks2ZkM4PpdlP88GYbIpgwotkDr4ogPn26H/mEBiPn2OM5fKcSQhbsR/UBTTHqoGRztOVEc0V8p6tHP7NmzsWDBApw/f77Sx/DRD9Ulpy4ZMPnrwziTVQAAeC48CP96tCXqccixTdJfK8Xb35/EusPpAIC2fyyC2NSPI7So9lP8o59b0ev18PKquHOZ0WiEwWAotxHVdiazwMId5zBw3k6cySqATz01lkZ1xgeD2zKk2DCtiyP++0x7zHuuI7TOjjierseAub/j8/gLUND/QxJJpZigkpSUhLlz52LcuHEVtouNjYVWq7VsOp2uhiokkiNTX4xhn+3BzE2JKDUJPNzKH5sn348HQ/1ll0ZW0j8sEJsn34/7mvmguNSMtzacxOgVB5BTYJRdGpF0Vn/0ExMTg1mzZlXYJiEhAaGhoZbX6enp6NmzJ3r16oXPPvuswmONRiOMxj9/eQ0GA3Q6HR/9UK205VQWXltzFHnXSuGitsdbA1rhmS46zotSS5nNAiviLyB2UyJKyszwddPgv0+3ZwdpqpWkDU/Ozs5GTk5OhW1CQkKgVl+f3+HSpUvo1asXunXrhuXLl8POrmo3edhHhWqj4lITYn9MwIr4iwCANg3c8X/PdkAIZ5etExIyDJj01WGcvVwAlQoYd38TvNqnOTvaUq1iE/OopKen44EHHkCnTp3wxRdfwN6+6nMJMKhQbZN0uQD/WHUIiZnXp8Afc9/1uTY0Dpxroy4pKjHhvY2nLMPP2zXU4pPnOkLn5SK5MiLrUHxQSU9PR69evdCoUSOsWLGiXEgJCAio9PswqFBtsvZQGt5cfwLXSkzwqafGnCHt0KuFn+yySKKfTmTgjW+PQ19UCjcnB8x+KgyPtAmUXRbRPVN8UFm+fDlGjRp1y69VpSQGFaoNikpMeGvDCXxzMA0AENnEGx8/2x5+bk6SKyMlSM8rwkurDuFQSh4AYGREI/yrf0veZSObpvigYi0MKmTrki7nY+KXh3Am63p/hMkPNcc/HmwKezt2mKU/lZrMmPPzaSzacX2eqTYN3DHvuY5o5O0quTKiu2OT86gQ1TUbjqTjsbm7cCarAL5uGnz5Yjhe7t2MIYVu4mhvh6n9WmJpVGd4uDjiRLoBA+buxM8nM2WXRlStGFSIJDCWmTBt/Qm8HHcERaUmRDbxxo+T7kNkEw5DpYo9GOqPHyfdh45BHsgvLsPYlQcR+2MCykxm2aURVQsGFaIalnb1Gp5eGI+Ve64PPX7pwaZYOTocvm4ayZWRrajv4Yy4sRF4oXswAGDRb+fx3Gd7cdlQLLkyIutjUCGqQTvOZGPA3J04mqaH1tkRy6K64NU+Lfioh6pM7WCHtx5rhfnDOqKexgH7knPx6P/txN7zFc9jRWRrGFSIaoDZLDD3l7OIWrYPeddK0a6hFhsn9cADoRx6TPfm0baB+O4f3dHC3w1XCox47rO9+Oz381wriGoNBhWiaqYvKsWYzw/gwy1nIMT1FY9Xj49AQ09O3EXWEeJbD+uiIzGwfX2Y
zALvbUzAS18dRqGxTHZpRPeMy64SVaPETAPGrTyIiznXoHaww3uD2uDpzlxIk6zPRe2Aj59pjw46D7y3MQE/HMvAmax8LHy+E5deIJvGOypE1eT7o5cweN5uXMy5hoaezlg7IZIhhaqVSqVCVPdgxI3tBj83Dc5kFWDgJ7vwS0KW7NKI7prUoPL4448jKCgITk5OCAwMxPDhw3Hp0iWZJRHdszKTGR/8eP3We1GpCfc188H3/+iBNg20skujOqJzYy/8MKkHujT2RL6xDKNXHMDHW8/AbGa/FbI9UoPKAw88gNWrV+P06dP49ttvce7cOTz11FMySyK6J7mFJRi5bB8+/e367KETejXB8lFd4emqllwZ1TV+bk748sVuGBHRCADw8dazGLvyIAzFpZIrI6oaRU2h/91332HQoEEwGo1wdHSs1DGcQp+U4uQlPcZ+fhDpeUVwUdtj9lPt0D+Mi8eRfN8cSMW/159ASZkZIT6u+HREZzT1Y78VksvmptDPzc3Fl19+icjIyApDitFohMFgKLcRyfb90Ut4csFupOcVoZG3C9ZN7M6QQooxpLMOa8ZHIFDrhPNXCjF4HvutkO2QHlTeeOMNuLq6wtvbGykpKdiwYUOF7WNjY6HVai2bTsfOiSSPySww66dEvPTVYRSXmtGzuS++i+6BFgFusksjKiesoQe+f6kHujb2Qr6xDC9+fgCf/HqW862Q4ln90U9MTAxmzZpVYZuEhASEhoYCAK5cuYLc3FxcvHgRM2bMgFarxQ8//ACV6tYzdRqNRhiNRstrg8EAnU7HRz9U4/RFpXg57jC2n84GAIzrGYLX+4ZylllStJIyM9794ZRlCYdH2wZg9lPt4KrhbBVUsyr76MfqQSU7Oxs5ORVP4RwSEgK1+ubOhWlpadDpdNi9ezciIiIq9f3YR4VkOJddgDErDuD8lUI4Odph1pNhGNi+geyyiCotbl8Kpm04gVKTQGiAGxaP6AydFychpJpT2c9vq0doX19f+Pr63tWxZvP11T//eseESGm2nb6MSasOI99YhgYezlg0vBOHHpPNebZrEJr518O4lYeQmJmPgfN2Yf6wjugW4i27NKJypI362bt3L/bv348ePXrA09MT586dw7Rp05CVlYWTJ09Co6ncSrK8o0I1RQiBT387j5k/JUIIoGtjL8x/viN86nHVY7JdGfoijP38II6n6+Fgp8L0x1tjeLdGssuiOkDxo35cXFywdu1aPPTQQ2jRogVGjx6NsLAw7Nixo9IhhaimFJea8MrXRxC76XpIGdo1CF+8GM6QQjYvUOuMb8ZH4PF29VFmFpi2/gT+te44SsrMsksjAqCweVTuBu+oUHXLMhRj7OcHcDRND3s7Fd5+rBWe79both2+iWyREAILd5zHfzZfD+PhwV5Y8HwneHGyQqomir+jQmQLjqXl4fFPduJomh4eLo5YOborhkc0ZkihWkelUmFCryb4bERn1NM4YG9yLgbO24nTmfmyS6M6jkGF6Da+O3oJQxbGI8tgRDO/etgQ3R2RTXxkl0VUrR5q6Y+1EyMR5OWC1NwiPDF/F7ac4uRwJA+DCtHfmM0CH/58GpO+OgxjmRkPhfph7cRINPJ2lV0aUY1o7u+GDdHd0S3EC4UlJoxdeQDztydxcjiSgkGF6C8KjWWY8OVBzP01CcD1Sdw+HdEZbk6VW3uKqLbwdFVj5ehwPN8tCEIA//npNF5dfRTFpSbZpVEdw6BC9IdLeUUYsjAem09mQW1vhw+HtMPUfi050yzVWY72dnhvUFu8O7A17O1UWHs4HUMX70F2Pue6oprDoEIE4FDKVTz+yS6cyjDAp54aX40Nx5OdGsoui0gRhkc0xopRXeHu5IDDKXkY+MlOnLykl10W1REMKlTnrT+cjmc/3YMrBUaEBrhhfXR3dGrkJbssIkXp0cwH66O7I8THFZf0xXhqQTw2n8yUXRbVAQwqVGeZzQJzNp/G5K+PoKTMjIdb+ePb
CZFo6Mn1TohuJcS3HtZN7I77mvmgqNSE8V8cxILt59jJlqoVgwrVSddKyhC96hA+2Xa90+yEXk2w6PlOXEGW6A60Lo5YFtUFIyIaQQhg1k+JePWbozCWsZMtVQ9FBBWj0Yj27dtDpVLhyJEjssuhWi5TX4ynF8Vj04lMONqrMGdIO7zxSCjs2GmWqFIc7O3wzsA2eOdGJ9tD6Ri2eC9yCtjJlqxPEUHl9ddfR/369WWXQXXA8TQ9Bs7biRPpBni5qrFqTDc8xU6zRHdlRERjLIvqAjcnBxy4eBUD5+3CmSzOZEvWJT2obNq0CT///DPmzJkjuxSq5TYdz8CQRbstM82un9gdXRqz0yzRvbi/uS/WTYxEI28XpF0twhPzd2P76cuyy6JaRGpQycrKwpgxY7By5Uq4uFSuA6PRaITBYCi3EVVECIF525Iw4ctDKC41o2dzX3w7MRJB3uw0S2QNTf3csH5id3QN9kKBsQwvLN+P5buS2cmWrEJaUBFCICoqCuPHj0fnzp0rfVxsbCy0Wq1l0+l01Vgl2TpjmQmvrj6K2ZtPAwCiIhtjycjOcOdMs0RW5emqxhejw/F054YwC+Dt709h2oYTKDWZZZdGNs7qQSUmJgYqlarCLTExEXPnzkV+fj6mTp1apfefOnUq9Hq9ZUtNTbX2KVAtkVNgxLDFe7H2cDrs7VR4d1AbvP14azjYS3/iSVQrqR3sMOvJMEztFwqVCvhiTwpeWL4f+qJS2aWRDVMJK9+by87ORk5OToVtQkJC8PTTT+P777+HSvXnSAuTyQR7e3sMGzYMK1asqNT3MxgM0Gq10Ov1cHd3v6faqfY4m5WPF1bsR2puEdycHDB/WEfc18xXdllEdcbPJzMx+esjuFZiQlO/elg6sgsft1I5lf38tnpQqayUlJRy/UsuXbqEvn37Ys2aNQgPD0fDhpUbicGgQn/3+9lsTPzyEPKLyxDk5YKlUZ3R1M9NdllEdc6JdD1eXHEAmYZieLmqsWh4J3ZgJwvFB5W/u3DhAoKDg3H48GG0b9++0scxqNBffbHnIqZ/dxIms0CXxp5YNLwzvFzVsssiqrOyDMV4ccUBHE/XQ21vh5lPtsUTHTklAFX+85sP66lWMJkFZnx/Em+uPwGTWeCJDg3wxYvhDClEkvm7O+Hrcd3wSOsAlJjMmLL6KD78+TTMZkX8PzLZAMXcUblbvKNCBcYyvPzVYfySeH3uhn/2aY7oB5qW6/9ERHKZzQKzfz6NBdvPAQAGhAVizpB2cHK0l1wZyVLZz28ubEI27VJeEV5Yvh+JmfnQONjhw6fbYUAYZzkmUho7OxXeeCQUwT6u+Nfa4/jhWAbS84rw6fDO8HXTyC6PFIyPfshmHUvLw8B5u5CYmQ+fehrEje3GkEKkcE931mHl6HBonR1xOCUPgzjtPt0BgwrZpJ9OZODpRfHIzjeihb8b1kdHokOQp+yyiKgSIpp4Y93ESDT2dkF6XhGenL8bO85kyy6LFIpBhWyKEAILd5zD+C/+nA5/zYQINPTk/AxEtiTEtx7WTeyO8GAv5P8x7f7KPRdll0UKxKBCNqPUZEbMt8cxc1MiAGBERCMsGdkZbpwOn8gmebqqsXJ0OJ7s2BAms8C09SfwzvenYOKIIPoLdqYlm6C/VoqJqw5iV1IO7FTAWwNaIap7sOyyiOgeqR3sMGdIGEJ8XTF782ks3ZWMlNxC/O/ZDnDV8COKeEeFbEBKzjU8sWAXdiXlwFVtj89GdmZIIapFVCoVoh9oik+e6wCNgx22JlzGkIXxyNAXyS6NFIBBhRTt4MVcDJq/C+eyCxGodcI34yPxYKi/7LKIqBoMCKuPr8Z2g089NU5lGDBo3i6cSNfLLoskkxpUGjdufNPKyjNnzpRZEinIhiPpGLp4L3ILS9C2gRYborujVX1O6kdUm3UM8sS6id3RzK8esgxGDFkYjy2nsmSXRRJJv6PyzjvvICMjw7K99NJLsksiyYQQ+L9f
zuLluCMoKTOjTyt/fD2uG/zcnWSXRkQ1QOflgjUTInFfMx8UlZowduUBfPb7edj4ROp0l6QHFTc3NwQEBFg2V1dX2SWRRMYyE15dfRQfbTkDABh7fwgWPt8JLmp2qiOqS7TOjlga1QVDuwZBCOC9jQmYtuEEykxm2aVRDZO61k/jxo1RXFyM0tJSBAUF4bnnnsMrr7wCB4fbfygZjUYYjUbLa4PBAJ1Ox7V+aoGrhSUY98VB7EvOhb2dCu8MbI1h4Y1kl0VEEgkhsPj384jdlAghgJ7NffHJcx04LUEtYBOrJ0+aNAlxcXHYtm0bxo0bhw8++ACvv/56hcfExsZCq9VaNp1OV0PVUnVKvlKIJxbsxr7kXLhpHLAsqgtDChFBpVJh7P1NsGBYJzg52mHHmWwMWRiP9DyOCKorrH5HJSYmBrNmzaqwTUJCAkJDQ2/av3TpUowbNw4FBQXQaG69SBXvqNQ++5JzMXblAeRdK0UDD2csjeqCFgFusssiIoU5lpaH0SsOIDvfCF83DZaM7Iywhh6yy6K7VNk7KlYPKtnZ2cjJyamwTUhICNRq9U37T548iTZt2iAxMREtWrSo1Per7ImSMq07nIY31hxHicmMdg21WDyyM/zc2GmWiG4tPa8Io/9YMd3J0Q7/e7YD+rYOkF0W3YXKfn5bvYeir68vfH197+rYI0eOwM7ODn5+flauipRGCIGPt57F/345CwDo1yYAHz3dHs5qe8mVEZGSNfBwxjfjI/CPVYex40w2xn9xEP/q1xIv3hcMlUoluzyqBtKGUsTHx2Pv3r144IEH4Obmhvj4eLzyyit4/vnn4enJVXBrM2OZCTHfHse6w+kAgHE9Q/BG31DY2fGPDBHdmZuTI5aM7Iy3vz+JL/ak4P0fE5CcU4h3Hm8NB3vpg1nJyqQFFY1Gg7i4OLz99tswGo0IDg7GK6+8gilTpsgqiWrA1cISjFt5EPsuXB/Z896gNhjaNUh2WURkYxzs7fDuwDZo7O2K939MwKq9KUi7WoR5HBFU60gdnmwN7KNiO5KvFOKF5fuRfKUQbhoHzH++I+5rdnePCYmIbvj5ZCZejjuColITQgPcsCSqCxp4OMsui+7AJoYnU92xLzkXg+fvQvKVQjTwcMa3EyMZUojIKvq0DsDqcRHwddMgMTMfg+btwrG0PNllkZUwqFC1W3c4Dc9/thd510rRTueB9dHd0dyfw4+JyHraNtRifXR3hAa4ITvfiKcXxWPzyUzZZZEVMKhQtbk+sucMXvn6KEpMZvRrE4C4Md3g63brOXKIiO7FjRFBPZv7orjUjPFfHOQaQbUAgwpVC2OZCVNWH8XHW68PPx53fwjmPdeRw4+JqFrdGBH0fDeuEVRbcKU3srq8ayUYu/LPNXs4soeIatLfRwR9sScFqblFXCPIRvGOClnVhSuFGDz/zzV7lo/qwpBCRDVOpVLhxfuur77+1zWCLnGNIJvDoEJWs/9C+ZE9ayZwZA8RydX3FiOCjqfpZZdFVcCgQlax4Ug6hi3ei6vXShHWUIt10ZFcWJCIFCGs4fXRhi383XD5jxFBW05lyS6LKolBhe6JEAJzfzmLl+OOoMRkRt/W/vh6bAQXFiQiRWng4YxvJkTgvmY+KCo1YezKA1i2K1l2WVQJ0oPKxo0bER4eDmdnZ3h6emLQoEGyS6JKKikz45/fHMOHW84AAMbcF4z5wzpxZA8RKZK7kyOWRl3vNycEMOP7U5jOEUGKJ3XUz7fffosxY8bggw8+wIMPPoiysjKcOHFCZklUSfprpRj3xQHsOX99ZM+Mx1vj+W6NZJdFRFQhR3s7fDC4DYJ9XPDBj4lYEX8RqVeLMHdoB7hqOBBWiaSt9VNWVobGjRtjxowZGD16dKWPMxqNMBqNltcGgwE6nY5r/dSglJxriFq+D+ezC+GqtscnwzrigRZ+sssiIqqSTcczMPnrIzCWmdEq0B1Lo7ogQMvH1jVF8Wv9HDp0COnp6bCz
s0OHDh0QGBiIfv363fGOSmxsLLRarWXT6XQ1VDEBwMGLVzF4/i6czy5EoNYJayZEMqQQkU3q1zYQcWO7waeeGqcyDBg0bxdOXTLILov+RlpQOX/+PADg7bffxptvvokffvgBnp6e6NWrF3Jzc2973NSpU6HX6y1bampqTZVc5208loGhi/cgp7AEbRq4Y310d7QM5F0sIrJdHYI8sW5idzTzq4dMQzGGLNyNbYmXZZdFf2H1oBITEwOVSlXhlpiYCLP5euelf//733jyySfRqVMnLFu2DCqVCt98881t31+j0cDd3b3cRtVLCIH525MQveoQSsrM6N3SD6vHRcDfnbdIicj26bxcsGZCJLo39UZhiQmjV+zHyvgLssuiP1i959Crr76KqKioCtuEhIQgIyMDANCqVSvLfo1Gg5CQEKSkpFi7LLpLpSYzpq0/gbj91+9cRUU2xrQBrWBvp5JcGRGR9WidHbEsqiveXH8cqw+kYdqGk7iYcw1TH23Jv3eSWT2o+Pr6wtf3zrORdurUCRqNBqdPn0aPHj0AAKWlpbhw4QIaNeLoESUwFJci+stD+P3sFdipgLcGtEJU92DZZRERVQu1gx1mPRmGRt6umL35ND7bmYyU3Gv4+Nn2cFFzRJAs0vqouLu7Y/z48Zg+fTp+/vlnnD59GhMmTAAADBkyRFZZ9Ie0q9fw1ILd+P3sFbio7bF4RGeGFCKq9VQqFaIfaIr/G9oBagc7/HwqC89+ugeX84tll1ZnSY2Is2fPhoODA4YPH46ioiKEh4fj119/haenp8yy6rxjaXkYveIAsvON8HPTYGlUF7RpoJVdFhFRjXm8XX3U1zphzOcHcCxNj8HzdmPZqC5o7s+lQWqatHlUrKWy47CpcjafzMTLcYdRXGpGaIAblkZ1QX0PZ9llERFJceFKIUYt34/kK4Vw0zhgwfOd0KOZj+yyagXFz6NCyiKEwGe/n8f4Lw6iuNSMns198c34CIYUIqrTGvu4Yu2ESHRt7IV8Yxmilu1D3D4O+KhJDCqEMpMZb204ifc2JkAIYFh4EJaM7Aw3J0fZpRERSefpqsbKF7tiUPv6KDMLxKw9jlk/JcJstukHEjaDQaWOKzCWYcznB7Byz0WoVMC/H22J9wa1gYM9/2kQEd2gcbDHf59pj0kPNQMALNh+Di/FHUZxqUlyZbUfP43qsEx9MZ5eGI9tp7OhcbDD/Oc6Ysz9IVCpOGcAEdHfqVQqTHm4OT4c0g6O9ipsPJaB5xbvQU6B8c4H011jUKmjTl36Y12LDAN86qnx9bgI9GsbKLssIiLFe7JTQ3z+QjjcnRxwKCUPg+fvxrnsAtll1VoMKnXQttOXMWThbmQaitHUrx7WTeyO9joP2WUREdmMiCbeWDuxO3RezkjJvYYn5u/GnvM5ssuqlRhU6piVey5i9PL9KCwxIbKJN76dEAmdl4vssoiIbM6N/9HrEOQBfVEphi/Zi7WH0mSXVeswqNQRJrPAez+cwrT1J2AWwJBODbF8VFdonTmyh4jobvnU0+CrMd3Qv20gSk0CU1YfxX+3nIGNT1GmKNKCyvbt22+7uvL+/ftllVUrFZWYMPHLg/hsZzIA4LW+LfCfp8KgdmBOJSK6V06O9pg7tAPG92wCAPjfL2fx6uqjMJZxRJA1SJuZtqSkBLm5ueX2TZs2Db/88gvOnTtX6ZEnnJm2YpfzizFmxQEcTdNDbW+H2UPCMLB9A9llERHVSl/tS8Gb60/AZBYID/bCouGd4OGill2WIil+Zlq1Wo2AgADL5u3tjQ0bNmDUqFEcHmslZ7LyMXjebhxN08PTxRFfjglnSCEiqkZDuwZhWVQX1NM4YG9yLp5YsBsXcwpll2XTFHPv/7vvvkNOTg5GjRpVYTuj0QiDwVBuo5vtSrqCJxfsRnpeEYJ9XLF2Ynd0aewluywiolrv/ua+WDMhAvW1TjifXYjB83fj4MWrssuyWYoJKkuWLEHfvn3RsGHDCtvF
xsZCq9VaNp1OV0MV2o7V+1Mxcuk+5BeXoWtjL6ydEIlgH1fZZRER1RmhAe5YH90dbRtokVtYgqGL92DjsQzZZdkkqweVmJiY23aSvbElJiaWOyYtLQ2bN2/G6NGj7/j+U6dOhV6vt2ypqanWPgWbZTYLzN6ciNe/PYYys8DA9vWx8sWu8HTl81Eioprm5+6Er8d1Q++W/igpMyN61SEs2H6OI4KqyOqdabOzs5GTU/GkNyEhIVCr//zwfPfddzF37lykp6fD0bFqw2XZmfa64lITXltzDN8fvQQAmPRgU7zycHP29yEiksxkFnhv4yks23UBADC0qw7vDGwDxzq+plplP78drP2NfX194evrW+n2QggsW7YMI0aMqHJIoetyC0sw9vMDOHDxKhzsVIh9oi2GdOYjMSIiJbC3U2H6Y63RyMsF7/xwCl/tS0Xa1SLMG9YR7lyl/o6kx7lff/0VycnJePHFF2WXYpPOZxdg8PxdOHDxKtycHPD5C10ZUoiIFCiqezA+Hd4Zzo72+P3sFQxZEI/0vCLZZSme9KCyZMkSREZGIjQ0VHYpNmefZejbNTT0dMbaCZGIbOojuywiIrqN3q388c34CPi5aXA6Kx+D5u3C8TS97LIUTdqEb9ZSV/uobDiSjte+OYYSkxntdB74bERn+LppZJdFRESVcCmvCC8s34/EzHw4O9rj/4Z2wMOt/GWXVaMUP+Eb3R0hBOb+chYvxx1BicmMvq39ETemG0MKEZENqe/hjG/GR+D+5r4oKjVh7MoDWLYrWXZZisSgYkNKysx4bc0xfLjlDABgzH3BWDCsE5zV9pIrIyKiqnJzcsSSkZ0xtGsQhABmfH8Kb393EiazTT/osDoGFRuhLypF1LJ9WHMwDXYq4N1BbfDv/q1gZ8fhx0REtsrR3g4fDG6DmH7X+2ku330B41YewLWSMsmVKQeDig1Izb2GJxfsxu5zOXBV22PJyC4Y3q2R7LKIiMgKVCoVxvdsgnnPdYTawQ5bEy7j6UXxuGwoll2aIjCoKNyR1DwMnr8LSZcL4O+uwerxEXgg1E92WUREZGX9wwLx1Zhu8HZV40S6AYPm7UJiJtezY1BRsJ9OZOLZT+NxpaAELQOvrxvRur5WdllERFRNOjXyxLqJ3RHi64pL+mI8tSAev53Jll2WVAwqCiSEwOLfzmPClwdRXGrGAy188c34CARqnWWXRkRE1SzI2wVrJ0QiPNgLBcYyjFq+H1/tS5FdljQMKgpTZjJj2oYTeP/HBAgBPN8tCItHdEY9jdVXOyAiIoXycFFj5ehwPNGhAUxmgalrj2PmpkSY6+CIIAYVBSkwluHFzw/giz0pUKmAN/u3xLsD28Chji9cRURUF6kd7PDh0+0wuXczAMDCHefw0leHUVxqklxZzZL6CXjmzBkMHDgQPj4+cHd3R48ePbBt2zaZJUmTqS/GkIXx2H46G06OdlgwrBNevC+Eqx8TEdVhKpUKk3s3x3+faQdHexU2Hs/Ac4v3IKfAKLu0GiM1qAwYMABlZWX49ddfcfDgQbRr1w4DBgxAZmamzLJq3KlL13t3J2QY4FNPjbixEXikTYDssoiISCEGd2iIlaPDoXV2xKGUPAyevxvnsgtkl1UjpK31c+XKFfj6+uK3337DfffdBwDIz8+Hu7s7tmzZgt69e9/yOKPRCKPxzyRpMBig0+lsdq2fbYmX8Y9Vh1BYYkJTv3pYFtUFOi8X2WUREZECncsuwKhl+5GSew1aZ0csGt4J3UK8ZZd1VxS/1o+3tzdatGiBzz//HIWFhSgrK8OiRYvg5+eHTp063fa42NhYaLVay6bT6WqwautaueciRq/Yj8ISEyKbeOPbCZEMKUREdFtNfOth3cRIdAjygL6oFMOX7MW6w2myy6pWUldPTktLw6BBg3Do0CHY2dnBz88PGzduRIcOHW57TG24o2I2C8RuSsDi368vQPVUp4b4YHBbqB3YaZaIiO6suNSEV1cfxcbj
GQCAyb2b4eWHmtlUv0Zpd1RiYmKgUqkq3BITEyGEQHR0NPz8/PD7779j3759GDRoEB577DFkZGTc9v01Gg3c3d3LbbakqMSECV8etISUf/ZpjtlPhTGkEBFRpTk52mPu0A4Y37MJAODjrWfx6uqjKCkzS67M+qx+RyU7Oxs5OTkVtgkJCcHvv/+OPn364OrVq+XCRrNmzTB69GjExMRU6vtVNpEpQXa+ES9+fgBHU/OgtrfD7CFhGNi+geyyiIjIhn21LwVvrj8Bk1mgW4gXFj3fGVoXR9ll3VFlP7+tPouYr68vfH1979ju2rVrAAA7u/J3Euzs7GA2175EeDYrH1HL9iM9rwgeLo74dHhndA32kl0WERHZuKFdg1DfwxnRXx7CnvO5GLxgF5ZHdUWQd+3o8yjteUNERAQ8PT0xcuRIHD16FGfOnMFrr72G5ORk9O/fX1ZZ1WJ30hU8sWA30vOK0PiPqZEZUoiIyFp6Nr+x1IoTzmcXYvD8XTiUclV2WVYhLaj4+Pjgp59+QkFBAR588EF07twZO3fuxIYNG9CuXTtZZVndNwdSMWLpPuQXl6FzI0+sndgdIb71ZJdFRES1zJ+L17ojp7AEQz/dgx+P377Pp62QOurHGpTaR0UIgY+2nMHcX5MAAAPCAjFnSDs4OdpLroyIiGqzQmMZXo47jK0JlwEAU/uFYuz9ypvpXPHzqNRmxjITJn99xBJSoh9ogv97tgNDChERVTtXjQMWDe+MqMjGAIDYTYn49/oTKDPZZv9PLslrZVcLSzBu5UHsu5ALBzsVPhjcFk93sd1J6YiIyPbY26nw9uOt0cjbBe/8cAqr9qYg/WoRPnmuA9yclD8i6K94R8WKLlwpxBMLdmPfhVy4aRywfFRXhhQiIpJmVPdgfDq8M5wd7bHjTDaGLIzHpbwi2WVVCYOKlRy4kIvB83ch+UohGng449uJkejRzEd2WUREVMc93Mofq8dFwNdNg8TMfAyevwsn0vWyy6o0BhUr+P7oJTz32V5cvVaKsIZarIuORHN/N9llERERAQDaNtRi3cRINPevhyyDEU8viscvCVmyy6oUBpV7IITAvG1JeOmrwygpM+PhVv6IG9sNfm5OsksjIiIqp6GnC9ZMiMR9zXxwrcSEMZ8fwIrdF2SXdUcMKnep1GRGzLfHMXvzaQDA6B7BWPh8J7io2T+ZiIiUyd3JEUujuuDZLjqYBTD9u5N45/tTMJmVO1MJP1Xvgr6oFBO/PIhdSTmwUwHTH2uNkX8MAyMiIlIyR3s7xD7RFkHeLvjPT6exdFcyUq9ew/+eba/I/9mWekfl0KFDePjhh+Hh4QFvb2+MHTsWBQUFMku6o7Sr1zBk4W7sSsqBi9oei0d0ZkghIiKbolKpMLFXU8wd2gFqBztsOZWFZz/dg8v5xbJLu4m0oHLp0iX07t0bTZs2xd69e/HTTz/h5MmTiIqKklXSHR1Ly8Pg+btxJqsAfm4arB4XgYda+ssui4iI6K481q4+Vr0YDk8XRxxL02PwvN04k5Uvu6xypE2h/+mnn2LatGnIyMiwrKB8/PhxhIWF4ezZs2jatGml3qemptD/+WQmJsUdRnGpGaEBblga1QX1PZyr7fsRERHVlAtXCvHC8v04f6UQbhoHLHi+U7VPsaH4KfSNRiPUarUlpACAs/P1D/6dO3dWeJzBYCi3VSchBJbsTMa4Lw6iuNSM+/9YoZIhhYiIaovGPq74dkIkujb2Qr6xDFHL9mH1/lTZZQGQGFQefPBBZGZmYvbs2SgpKcHVq1cRExMDAMjIuP1qj7GxsdBqtZZNp6u+mV9NZoG3vzuJd384BSGAoV2DsGRkZ5ubfpiIiOhOPF3VWPliVwxsXx9lZoHXvz2G2ZsTYZY8IsjqQSUmJgYqlarCLTExEa1bt8aKFSvw4YcfwsXFBQEBAQgODoa/v3+5uyx/N3XqVOj1esuWmlo9ia/QWIaxnx/AiviL179vv1B8MLgNHO05
opuIiGonjYM9Pn6mPSY9eL37xbxt5/Dy10dQXGqSVpPV+6hkZ2cjJyenwjYhISFQq9WW11lZWXB1dYVKpYK7uzvi4uIwZMiQSn2/6uqj8uKK/diacBkaBzv895n2eLRtoNXem4iISOm+OZCKqWuPo8ws8ESHBvjomfZWff/Kfn5bfcC0r68vfH19q3SMv//1kTNLly6Fk5MTHn74YWuXVWVTHm6BM1kF+PjZ9ugY5Cm7HCIioho1pLMODTycEbP2OP7xYOUGuFQHaaN+AOCTTz5BZGQk6tWrhy1btuC1117DzJkzMWnSpEq/R3WO+ik1mfmoh4iI6rTq+iyUdkelKvbt24fp06ejoKAAoaGhWLRoEYYPHy6zpHIYUoiIqK6T/VkoNah8/vnnMr89ERERKRxvGRAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWJVW1B5//33ERkZCRcXF3h4eNyyTUpKCvr37w8XFxf4+fnhtddeQ1lZWXWVRERERDam2qbQLykpwZAhQxAREYElS5bc9HWTyYT+/fsjICAAu3fvRkZGBkaMGAFHR0d88MEH1VUWERER2ZBqXz15+fLlmDx5MvLy8srt37RpEwYMGIBLly7B398fALBw4UK88cYbyM7OhlqtvuX7GY1GGI1Gy2u9Xo+goCCkpqZaffVkIiIiqh4GgwE6nQ55eXnQarW3bSdtUcL4+Hi0bdvWElIAoG/fvpgwYQJOnjyJDh063PK42NhYzJgx46b9Op2u2molIiKi6pGfn6/MoJKZmVkupACwvM7MzLztcVOnTsWUKVMsr81mM3Jzc+Ht7Q2VSmXVGm+kvdp6t4bnZ/tq+zny/GxfbT9Hnt/dE0IgPz8f9evXr7BdlYJKTEwMZs2aVWGbhIQEhIaGVuVtq0Sj0UCj0ZTbd7vOutbi7u5eK/8B3sDzs321/Rx5fravtp8jz+/uVHQn5YYqBZVXX30VUVFRFbYJCQmp1HsFBARg37595fZlZWVZvkZERERUpaDi6+sLX19fq3zjiIgIvP/++7h8+TL8/PwAAFu2bIG7uztatWplle9BREREtq3a+qikpKQgNzcXKSkpMJlMOHLkCACgadOmqFevHvr06YNWrVph+PDh+M9//oPMzEy8+eabiI6OvunRjiwajQbTp09XTD3WxvOzfbX9HHl+tq+2nyPPr/pV2/DkqKgorFix4qb927ZtQ69evQAAFy9exIQJE7B9+3a4urpi5MiRmDlzJhwcpPXxJSIiIgWp9nlUiIiIiO4W1/ohIiIixWJQISIiIsViUCEiIiLFYlAhIiIixarTQeX9999HZGQkXFxcKj27rRACb731FgIDA+Hs7IzevXvj7Nmz5drk5uZi2LBhcHd3h4eHB0aPHo2CgoJqOIOKVbWOCxcuQKVS3XL75ptvLO1u9fW4uLiaOKWb3M3PulevXjfVP378+HJtUlJS0L9/f7i4uMDPzw+vvfYaysrKqvNUbqmq55ebm4uXXnoJLVq0gLOzM4KCgjBp0iTo9fpy7WRew3nz5qFx48ZwcnJCeHj4TRM//t0333yD0NBQODk5oW3btvjxxx/Lfb0yv5M1qSrnt3jxYtx3333w9PSEp6cnevfufVP7qKiom67VI488Ut2ncVtVOb/ly5ffVLuTk1O5Nkq7fkDVzvFWf09UKhX69+9vaaOka/jbb7/hscceQ/369aFSqbB+/fo7HrN9+3Z07NgRGo0GTZs2xfLly29qU9Xf6yoRddhbb70lPvroIzFlyhSh1WordczMmTOFVqsV69evF0ePHhWPP/64CA4OFkVFRZY2jzzyiGjXrp3Ys2eP+P3330XTpk3F0KFDq+ksbq+qdZSVlYmMjIxy24wZM0S9evVEfn6+pR0AsWzZsnLt/nr+NeluftY9e/YUY8aMKVe/Xq+3fL2srEy0adNG9O7dWxw+fFj8+OOP
wsfHR0ydOrW6T+cmVT2/48ePiyeeeEJ89913IikpSfzyyy+iWbNm4sknnyzXTtY1jIuLE2q1WixdulScPHlSjBkzRnh4eIisrKxbtt+1a5ewt7cX//nPf8SpU6fEm2++KRwdHcXx48ctbSrzO1lTqnp+zz33nJg3b544fPiwSEhIEFFRUUKr1Yq0tDRLm5EjR4pHHnmk3LXKzc2tqVMqp6rnt2zZMuHu7l6u9szMzHJtlHT9hKj6Oebk5JQ7vxMnTgh7e3uxbNkySxslXcMff/xR/Pvf/xZr164VAMS6desqbH/+/Hnh4uIipkyZIk6dOiXmzp0r7O3txU8//WRpU9WfWVXV6aByw7JlyyoVVMxmswgICBCzZ8+27MvLyxMajUZ89dVXQgghTp06JQCI/fv3W9ps2rRJqFQqkZ6ebvXab8dadbRv31688MIL5fZV5h93Tbjbc+zZs6d4+eWXb/v1H3/8UdjZ2ZX7g7pgwQLh7u4ujEajVWqvDGtdw9WrVwu1Wi1KS0st+2Rdw65du4ro6GjLa5PJJOrXry9iY2Nv2f7pp58W/fv3L7cvPDxcjBs3TghRud/JmlTV8/u7srIy4ebmJlasWGHZN3LkSDFw4EBrl3pXqnp+d/rbqrTrJ8S9X8P//ve/ws3NTRQUFFj2Keka/lVl/g68/vrronXr1uX2PfPMM6Jv376W1/f6M7uTOv3op6qSk5ORmZmJ3r17W/ZptVqEh4cjPj4eABAfHw8PDw907tzZ0qZ3796ws7PD3r17a6xWa9Rx8OBBHDlyBKNHj77pa9HR0fDx8UHXrl2xdOlSCAnT8dzLOX755Zfw8fFBmzZtMHXqVFy7dq3c+7Zt27bc6t59+/aFwWDAyZMnrX8it2Gtf0t6vR7u7u43TaRY09ewpKQEBw8eLPf7Y2dnh969e1t+f/4uPj6+XHvg+rW40b4yv5M15W7O7++uXbuG0tJSeHl5ldu/fft2+Pn5oUWLFpgwYQJycnKsWntl3O35FRQUoFGjRtDpdBg4cGC53yElXT/AOtdwyZIlePbZZ+Hq6lpuvxKu4d240++gNX5md8IpYKsgMzMTAMp9gN14feNrmZmZlrWLbnBwcICXl5elTU2wRh1LlixBy5YtERkZWW7/O++8gwcffBAuLi74+eefMXHiRBQUFGDSpElWq78y7vYcn3vuOTRq1Aj169fHsWPH8MYbb+D06dNYu3at5X1vdY1vfK2mWOMaXrlyBe+++y7Gjh1bbr+Ma3jlyhWYTKZb/mwTExNvecztrsVff99u7Ltdm5pyN+f3d2+88Qbq169f7o/+I488gieeeALBwcE4d+4c/vWvf6Ffv36Ij4+Hvb29Vc+hIndzfi1atMDSpUsRFhYGvV6POXPmIDIyEidPnkTDhg0Vdf2Ae7+G+/btw4kTJ7BkyZJy+5VyDe/G7X4HDQYDioqKcPXq1Xv+d38ntS6oxMTEYNasWRW2SUhIQGhoaA1VZF2VPb97VVRUhFWrVmHatGk3fe2v+zp06IDCwkLMnj3bah9y1X2Of/3Qbtu2LQIDA/HQQw/h3LlzaNKkyV2/b2XV1DU0GAzo378/WrVqhbfffrvc16r7GlLVzZw5E3Fxcdi+fXu5DqfPPvus5b/btm2LsLAwNGnSBNu3b8dDDz0ko9RKi4iIQEREhOV1ZGQkWrZsiUWLFuHdd9+VWFn1WLJkCdq2bYuuXbuW22/L11AJal1QefXVVxEVFVVhm5CQkLt674CAAABAVlYWAgMDLfuzsrLQvn17S5vLly+XO66srAy5ubmW4+9FZc/vXutYs2YNrl27hhEjRtyxbXh4ON59910YjUarLFxVU+d4Q3h4OAAgKSkJTZo0QUBAwE091rOysgDAZq5hfn4+HnnkEbi5uWHdunVwdHSssL21r+Gt+Pj4wN7e3vKzvCErK+u25xMQEFBh+8r8TtaUuzm/G+bMmYOZM2di69atCAsLq7BtSEgIfHx8kJSUVKMf
cvdyfjc4OjqiQ4cOSEpKAqCs6wfc2zkWFhYiLi4O77zzzh2/j6xreDdu9zvo7u4OZ2dn2Nvb3/O/izuySk8XG1fVzrRz5syx7NPr9bfsTHvgwAFLm82bN0vrTHu3dfTs2fOmkSK389577wlPT8+7rvVuWetnvXPnTgFAHD16VAjxZ2fav/ZYX7RokXB3dxfFxcXWO4E7uNvz0+v1olu3bqJnz56isLCwUt+rpq5h165dxT/+8Q/La5PJJBo0aFBhZ9oBAwaU2xcREXFTZ9qKfidrUlXPTwghZs2aJdzd3UV8fHylvkdqaqpQqVRiw4YN91xvVd3N+f1VWVmZaNGihXjllVeEEMq7fkLc/TkuW7ZMaDQaceXKlTt+D5nX8K9Qyc60bdq0Kbdv6NChN3WmvZd/F3es0yrvYqMuXrwoDh8+bBmCe/jwYXH48OFyQ3FbtGgh1q5da3k9c+ZM4eHhITZs2CCOHTsmBg4ceMvhyR06dBB79+4VO3fuFM2aNZM2PLmiOtLS0kSLFi3E3r17yx139uxZoVKpxKZNm256z++++04sXrxYHD9+XJw9e1bMnz9fuLi4iLfeeqvaz+dWqnqOSUlJ4p133hEHDhwQycnJYsOGDSIkJETcf//9lmNuDE/u06ePOHLkiPjpp5+Er6+vtOHJVTk/vV4vwsPDRdu2bUVSUlK54ZBlZWVCCLnXMC4uTmg0GrF8+XJx6tQpMXbsWOHh4WEZYTV8+HARExNjab9r1y7h4OAg5syZIxISEsT06dNvOTz5Tr+TNaWq5zdz5kyhVqvFmjVryl2rG3+D8vPzxT//+U8RHx8vkpOTxdatW0XHjh1Fs2bNajQ03+35zZgxQ2zevFmcO3dOHDx4UDz77LPCyclJnDx50tJGSddPiKqf4w09evQQzzzzzE37lXYN8/PzLZ91AMRHH30kDh8+LC5evCiEECImJkYMHz7c0v7G8OTXXntNJCQkiHnz5t1yeHJFP7N7VaeDysiRIwWAm7Zt27ZZ2uCP+SZuMJvNYtq0acLf319oNBrx0EMPidOnT5d735ycHDF06FBRr1494e7uLkaNGlUu/NSUO9WRnJx80/kKIcTUqVOFTqcTJpPppvfctGmTaN++vahXr55wdXUV7dq1EwsXLrxl25pQ1XNMSUkR999/v/Dy8hIajUY0bdpUvPbaa+XmURFCiAsXLoh+/foJZ2dn4ePjI1599dVyw3trSlXPb9u2bbf8Nw1AJCcnCyHkX8O5c+eKoKAgoVarRdeuXcWePXssX+vZs6cYOXJkufarV68WzZs3F2q1WrRu3Vps3Lix3Ncr8ztZk6pyfo0aNbrltZo+fboQQohr166JPn36CF9fX+Ho6CgaNWokxowZY7UPgLtRlfObPHmypa2/v7949NFHxaFDh8q9n9KunxBV/zeamJgoAIiff/75pvdS2jW83d+IG+c0cuRI0bNnz5uOad++vVCr1SIkJKTcZ+INFf3M7pVKCAnjSomIiIgqgfOoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFi/T8qIITBChX6pQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# log(1/(1 + exp(-x - eps)) - 1/(1 + exp(-x + eps)))\n",
    "# plot \n",
    "inv_scale = 10\n",
    "f = lambda x : np.log(1/(1 + np.exp(- inv_scale * (x  + 0.1))) - 1/(1 + np.exp(- inv_scale* (x  - 0.1))))\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "x = np.linspace(-1, 1, 100)\n",
    "y = f(x)\n",
    "plt.plot(x, y)\n",
    "plt.yticks(np.arange(-10, 10, 1))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100,\n",
       "        0.9100])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_classes = 10\n",
    "beta = 0.1\n",
    "mat = torch.ones(num_classes, num_classes) * beta / num_classes\n",
    "# diagonal is 1 - (K - 1)/K * beta\n",
    "mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100,\n",
       "         0.0100],\n",
       "        [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
       "         0.9100]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 1, 32, 32, 16])\n"
     ]
    }
   ],
   "source": [
    "from typing import Dict, Tuple\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image, make_grid\n",
    "\n",
    "\n",
    "blk = lambda ic, oc: nn.Sequential(\n",
    "    nn.Conv2d(ic, oc, 7, padding=3),\n",
    "    nn.GroupNorm(oc // 8, oc, affine = False, eps = 1e-4),\n",
    "    nn.LeakyReLU(),\n",
    ")\n",
    "\n",
    "\n",
    "class DummyX0Model(nn.Module):\n",
    "    \"\"\"\n",
    "    This should be unet-like, but let's not think about the model too much :P\n",
    "    Basically, any universal R^n -> R^n model should work.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_channel: int, N) -> None:\n",
    "        super(DummyX0Model, self).__init__()\n",
    "        self.start =  blk(n_channel, 16)\n",
    "        self.pe = nn.Parameter(torch.randn(1, 16, 32, 32))\n",
    "        self.conv = nn.Sequential(\n",
    "            blk(16, 128),\n",
    "            blk(128, 256),\n",
    "            blk(256, 512),\n",
    "            blk(512, 256),\n",
    "            blk(256, 128),\n",
    "            blk(128, 64),\n",
    "            nn.Conv2d(64, 2 * n_channel, 3, padding=1),\n",
    "        )\n",
    "        self.N = N\n",
    "\n",
    "    def forward(self, x, t) -> torch.Tensor:\n",
    "        # Let's think about using t later. In the paper, they used Tr-like positional embeddings.\n",
    "        st = x.float() / self.N - 0.5\n",
    "        x = self.start(st) + self.pe\n",
    "        y = self.conv(x)\n",
    "        loc, log_scale = y.chunk(2, dim=1)\n",
    "\n",
    "        return torch.tanh(loc + st), log_scale\n",
    "\n",
    "blk = lambda ic, oc: nn.Sequential(\n",
    "    nn.Conv2d(ic, oc, 5, padding=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    "    nn.Conv2d(oc, oc, 5, padding=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    "    nn.Conv2d(oc, oc, 5, padding=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    ")\n",
    "\n",
    "blku = lambda ic, oc: nn.Sequential(\n",
    "    nn.Conv2d(ic, oc, 5, padding=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    "    nn.Conv2d(oc, oc, 5, padding=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    "    nn.Conv2d(oc, oc, 5, padding=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    "    nn.ConvTranspose2d(oc, oc, 2, stride=2),\n",
    "    nn.GroupNorm(oc // 8, oc),\n",
    "    nn.LeakyReLU(),\n",
    "    \n",
    ")\n",
    "\n",
    "class DummyX0Model(nn.Module):\n",
    "    \n",
    "\n",
    "    def __init__(self, n_channel: int, N : int = 16) -> None:\n",
    "        super(DummyX0Model, self).__init__()\n",
    "        self.down1 = blk(n_channel, 16)\n",
    "        self.down2 = blk(16, 32)\n",
    "        self.down3 = blk(32, 64)\n",
    "        self.down4 = blk(64, 512)\n",
    "        self.down5 = blk(512, 512)\n",
    "        self.up1 = blku(512, 512)\n",
    "        self.up2 = blku(512 + 512, 64)  # Corrected to account for concatenated feature maps\n",
    "        self.up3 = blku(64 + 64, 32)    # Corrected to account for concatenated feature maps\n",
    "        self.up4 = blku(32 + 32, 16)    # Corrected to account for concatenated feature maps\n",
    "        self.convlast = blk(32, 16)\n",
    "        self.final = nn.Conv2d(16, N * n_channel, 1, bias = False)\n",
    "\n",
    "        # initialize final with zero\n",
    "        #self.final.weight.data.zero_()\n",
    "        #self.final.bias.data.zero_()\n",
    "\n",
    "        self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n",
    "        self.tr2 = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n",
    "        self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8)\n",
    "\n",
    "        self.temb_1 = nn.Linear(32, 16)\n",
    "        self.temb_2 = nn.Linear(32, 32)\n",
    "        self.temb_3 = nn.Linear(32, 64)\n",
    "        self.temb_4 = nn.Linear(32, 512)\n",
    "        self.N = N\n",
    "\n",
    "    def forward(self, x, t) -> torch.Tensor:\n",
    "        x = (2 * x.float() / self.N) - 1.0\n",
    "        t = t.float().reshape(-1, 1) / 500\n",
    "        t_as_sin = [torch.sin(t * 3.1415 * 2 ** i) for i in range(16)]\n",
    "        t_as_cos = [torch.cos(t * 3.1415 * 2 ** i) for i in range(16)]\n",
    "        # concat and send it to t_emb\n",
    "        t_emb_1 = self.temb_1(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
    "        t_emb_2 = self.temb_2(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
    "        t_emb_3 = self.temb_3(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
    "        t_emb_4 = self.temb_4(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
    "        \n",
    "        x1 = self.down1(x) + t_emb_1\n",
    "        x2 = self.down2(nn.functional.avg_pool2d(x1, 2)) + t_emb_2\n",
    "        x3 = self.down3(nn.functional.avg_pool2d(x2, 2)) + t_emb_3\n",
    "        x4 = self.down4(nn.functional.avg_pool2d(x3, 2)) + t_emb_4\n",
    "        x5 = self.down5(nn.functional.avg_pool2d(x4, 2))\n",
    "\n",
    "        x5 = self.tr1(x5.reshape(x5.shape[0], x5.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(x5.shape)\n",
    "\n",
    "        y = self.up1(x5)\n",
    "\n",
    "        y = self.tr2(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(y.shape)\n",
    "  \n",
    "        y = self.up2(torch.cat([x4, y], dim=1))\n",
    "\n",
    "        y = self.tr3(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(y.shape)\n",
    "\n",
    "        y = self.up3(torch.cat([x3, y], dim=1))\n",
    "\n",
    "        y = self.up4(torch.cat([x2, y], dim=1))\n",
    "\n",
    "\n",
    "        y = self.convlast(torch.cat([x1, y], dim=1))\n",
    "        y = self.final(y)\n",
    "        # reshape to B, C, H, W, N\n",
    "        y = y.reshape(y.shape[0], -1, self.N, *x.shape[2:]).transpose(2, -1).contiguous()\n",
    "      \n",
    "        return y\n",
    "\n",
    "\n",
    "# check if it takes 2, 1, 28, 28 input\n",
    "\n",
    "model = DummyX0Model(1, N = 16)\n",
    "print(model(torch.randn(2, 1, 32, 32), torch.tensor([1, 2])).shape)\n",
    "\n",
    "\n",
    "\n",
    "def get_logits_from_logistic_pars(loc, log_scale, num_classes = 10):\n",
    "    loc = loc.unsqueeze(-1)\n",
    "    log_scale = log_scale.unsqueeze(-1)\n",
    "    inv_scale = (-log_scale + 2.0).exp()\n",
    "\n",
    "    bin_width = 2.0 / (num_classes - 1)\n",
    "    bin_centers = torch.linspace(-1.0, 1.0, num_classes).to(loc.device)\n",
    "    bin_centers = bin_centers.reshape([1] * (len(loc.shape) - 1) + [num_classes])\n",
    "    bin_centers = bin_centers - loc\n",
    "    log_cdf_min = -torch.log1p((-inv_scale * (bin_centers - 0.5 * bin_width)).exp())\n",
    "    log_cdf_plus = -torch.log1p((-inv_scale * (bin_centers + 0.5 * bin_width)).exp())\n",
    "    logits = log_minus_exp(log_cdf_plus, log_cdf_min)\n",
    "    return logits\n",
    "\n",
    "\n",
    "def log_minus_exp(a, b, epsilon=1e-6):\n",
    "    return a + torch.log1p(-torch.exp(b - a) + epsilon)\n",
    "\n",
    "\n",
    "\n",
    "class D3PM(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        x0_model: nn.Module,\n",
    "        n_T: int,\n",
    "        num_classes: int = 10,\n",
    "        forward_type = 'uniform',\n",
    "        hybrid_loss_coeff = 0.001\n",
    "    ) -> None:\n",
    "        super(D3PM, self).__init__()\n",
    "        self.x0_model = x0_model\n",
    "\n",
    "        self.n_T = n_T\n",
    "        self.hybrid_loss_coeff = hybrid_loss_coeff\n",
    "        self.beta_t = [1 / (self.n_T - t + 1) for t in range(1, self.n_T + 1)]\n",
    "        self.eps = 1e-8\n",
    "        self.num_classses = num_classes\n",
    "        q_onestep_mats = []\n",
    "        q_mats = [] # these are cumulative\n",
    "\n",
    "        for beta in self.beta_t:\n",
    "\n",
    "            if forward_type == 'uniform':\n",
    "                mat = torch.ones(num_classes, num_classes) * beta / num_classes\n",
    "                mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)\n",
    "                q_onestep_mats.append(mat)\n",
    "            else:\n",
    "                raise NotImplementedError\n",
    "        q_one_step_mats = torch.stack(q_onestep_mats, dim=0)\n",
    "\n",
    "        q_one_step_transposed = q_one_step_mats.transpose(1, 2) # this will be used for q_posterior_logits\n",
    "\n",
    "        q_mat_t = q_onestep_mats[0]\n",
    "        q_mats = [q_mat_t]\n",
    "        for idx in range(1, self.n_T):\n",
    "            q_mat_t = q_mat_t @ q_onestep_mats[idx]\n",
    "            q_mats.append(q_mat_t)\n",
    "        q_mats = torch.stack(q_mats, dim=0)\n",
    "        self.logit_type = 'logit'\n",
    "\n",
    "        # register\n",
    "        self.register_buffer(\"q_one_step_transposed\", q_one_step_transposed)\n",
    "        self.register_buffer(\"q_mats\", q_mats)\n",
    "\n",
    "        assert self.q_mats.shape == (self.n_T, num_classes, num_classes), self.q_mats.shape\n",
    "    \n",
    "    def _at(self, a, t, x):\n",
    "        # t is 1-d, x is integer value of 0 to num_classes - 1\n",
    "        bs = t.shape[0]\n",
    "        t = t.reshape((bs, *[1] * (x.dim() - 1)))\n",
    "        #out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m]\n",
    "        return a[t - 1, x, :]\n",
    "\n",
    "\n",
    "    def q_posterior_logits(self, x_0, x_t, t):\n",
    "        # if t == 1, this means we return the L_0 loss, so we directly predict the x_0 logits.\n",
    "        # otherwise, we return the L_{t-1} loss.\n",
    "        # Also, we never have t == 0.\n",
    "\n",
    "        # if x_0 is integer, we convert it to one-hot.\n",
    "        if x_0.dtype == torch.int64 or x_0.dtype == torch.int32:\n",
    "            x_0_logits = torch.log(torch.nn.functional.one_hot(x_0, self.num_classses) + self.eps)\n",
    "        else:\n",
    "            x_0_logits = x_0.clone()\n",
    "\n",
    "        assert x_0_logits.shape == x_t.shape + (self.num_classses,), print(f\"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}\")\n",
    "\n",
    "        # Here, we calculate equation (3) of the paper. Note that the x_0 Q_t x_t^T is a normalizing constant, so we don't deal with that.\n",
    "        # fact1 is \"guess of x_{t-1}\" from x_t\n",
    "        # fact2 is \"guess of x_{t-1}\" from x_0\n",
    "\n",
    "        fact1 = self._at(self.q_one_step_transposed, t, x_t)\n",
    "        #fact2 = self._at_onehot(self.q_mats, t-1, )\n",
    "        # x, a[t-1]\n",
    "\n",
    "\n",
    "        softmaxed = torch.softmax(x_0_logits, dim=-1) # bs, ..., num_classes\n",
    "        qmats2 = self.q_mats[t-2] # bs, num_classes, num_classes\n",
    "\n",
    "        fact2 = torch.einsum('b...c,bcd->b...d', softmaxed, qmats2)\n",
    "    \n",
    "        #print(f\"Fact1Fact2\", fact1.shape, fact2.shape)\n",
    "        out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)\n",
    "        #print(f\"out: {out.shape}\")\n",
    "        t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim())))\n",
    "        #print(t_broadcast.shape)\n",
    "        bc = torch.where(t_broadcast == 1, x_0_logits, out)\n",
    "        #print(f\"bc: {bc.shape}\")\n",
    "        return bc\n",
    "\n",
    "    def vb(self, dist1, dist2):\n",
    "        out = (torch.softmax(dist1 + self.eps, dim = -1)*(torch.log_softmax(dist1 + self.eps, dim = -1) - torch.log_softmax(dist2 + self.eps, dim = -1)))\n",
    "        return out.sum(dim=-1).mean()\n",
    "\n",
    "        \n",
    "    def q_sample(self, x_0, t, noise):\n",
    "        # forward process, x_0 is the clean input.\n",
    "        \n",
    "        logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)\n",
    "        noise = torch.clip(noise, self.eps, 1.0)\n",
    "        gumbel_noise = -torch.log(-torch.log(noise))\n",
    "        return torch.argmax(logits + gumbel_noise, dim=-1)\n",
    "\n",
    "    def model_predict(self, x_0, t):\n",
    "        if self.logit_type == 'logit':\n",
    "            predicted_x0_logits = self.x0_model(x_0, t)\n",
    "        else:\n",
    "            loc, log_scale = self.x0_model(x_0, t)\n",
    "            predicted_x0_logits = get_logits_from_logistic_pars(loc, log_scale, self.num_classses)\n",
    "        return predicted_x0_logits\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Makes forward diffusion x_t from x_0, and tries to guess x_0 value from x_t using x0_model.\n",
    "        x is an integer tensor of shape (bs, ...), with values in 0 to num_classes - 1\n",
    "        \"\"\"\n",
    "        t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)\n",
    "        x_t = self.q_sample(x, t, torch.rand((*x.shape, self.num_classses), device=x.device))\n",
    "        # x_t is same shape as x\n",
    "        assert x_t.shape == x.shape, print(f\"x_t.shape: {x_t.shape}, x.shape: {x.shape}\")\n",
    "        # we use hybrid loss.\n",
    "        \n",
    "        predicted_x0_logits = self.model_predict(x, t)\n",
    "\n",
    "\n",
    "        # based on this, we first do vb loss.\n",
    "        true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)\n",
    "        #print(f\"predicted_x0_logits: {predicted_x0_logits.shape}, true_q_posterior_logits: {true_q_posterior_logits.shape}\")\n",
    "        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)\n",
    "\n",
    "        vb_loss = self.vb(true_q_posterior_logits,pred_q_posterior_logits)\n",
    "\n",
    "\n",
    "\n",
    "        predicted_x0_logits = predicted_x0_logits.flatten(start_dim = 0, end_dim = -2)\n",
    "        x = x.flatten(start_dim = 0, end_dim = -1)\n",
    "        #print(f\"predicted_x0_logits: {predicted_x0_logits.shape}, x: {x.shape}\")\n",
    "\n",
    "\n",
    "        ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x)\n",
    "\n",
    "        return vb_loss + ce_loss*self.hybrid_loss_coeff, {\"vb_loss\": vb_loss.detach().item(), \"ce_loss\": ce_loss.detach().item()}\n",
    "\n",
    "    def p_sample(self, x, t, noise):\n",
    "        \n",
    "        predicted_x0_logits = self.model_predict(x, t)\n",
    "        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)\n",
    "\n",
    "        noise = torch.clip(noise, self.eps, 1.0)\n",
    "\n",
    "        not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim())))\n",
    "        gumbel_noise = -torch.log(-torch.log(noise))\n",
    "        sample = torch.argmax(pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1)\n",
    "        return sample\n",
    "\n",
    "    def sample(self, x = None):\n",
    "        for t in reversed(range(1, self.n_T)):\n",
    "            t = torch.tensor([t]*x.shape[0], device=x.device)\n",
    "            x = self.p_sample(x, t, torch.rand((*x.shape, self.num_classses), device=x.device))\n",
    "     \n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#d3pm = D3PM(DummyX0Model(1, 20), 1000, num_classes = 20).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#d3pm.sample(torch.randint(0, 2, (2, 1, 32, 32)).cuda())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# d3pm.train()\n",
    "# d3pm.forward(torch.randint(0, 2, (128, 1, 28, 28)).cuda())\n",
    "# dataset = MNIST(\n",
    "#         \"./data\",\n",
    "#         train=True,\n",
    "#         download=True,\n",
    "#         transform=transforms.ToTensor()\n",
    "#     )\n",
    "# dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=20)\n",
    "# d3pm(x = next(iter(dataloader))[0].cuda().long())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Param Count: 63336704\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 30, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
      "  warnings.warn(_create_warning_msg(\n"
     ]
    }
   ],
   "source": [
    "N = 16\n",
    "d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes = N, hybrid_loss_coeff=0.0).cuda()\n",
    "print(f\"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}\")\n",
    "dataset = MNIST(\n",
    "        \"./data\",\n",
    "        train=True,\n",
    "        download=True,\n",
    "        transform= transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Pad(2),\n",
    "        ])\n",
    "    )\n",
    "dataloader = DataLoader([dataset[0]] * 50000, batch_size=32, shuffle=True, num_workers=32)\n",
    "#dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32)\n",
    "\n",
    "optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=4e-4, betas=(0.95, 0.99))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1563 [00:00<?, ?it/s]/root/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 30, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
      "  warnings.warn(_create_warning_msg(\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0.0025, norm: 0.0052, param_norm: 923.3715, vb_loss: 0.0025, ce_loss: 2.5743:   0%|          | 0/1563 [00:02<?, ?it/s]"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl50lEQVR4nO3dfXAV1fkH8G9CyA0KuSk4JKYQklZHfC00CEadKpoWqaOiTKsOrcE6ddTEGpmpilbbsUOD44xiO4jTjk3sVIqlI7G+oGMDYrHhLSW+lBpx5G3UG2qd5MYXAianf/jj/s4ecp/ds7t3703y/cxkJje7d/fcs7uXw3nOeU6eUkqBiIiIKCL52S4AERERjS5sfBAREVGk2PggIiKiSLHxQURERJFi44OIiIgixcYHERERRYqNDyIiIooUGx9EREQUKTY+iIiIKFJsfBAREVGkMtb4WLlyJSorK1FUVIQ5c+Zg27ZtmToVERERDSN5mVjb5amnnsJ1112Hxx57DHPmzMGKFSuwdu1adHV1YfLkyeJ7BwcH8cEHH2DChAnIy8sLu2hERESUAUop9PX1oby8HPn5Ln0bKgNmz56t6uvrU68HBgZUeXm5ampqcn3vgQMHFAD+8Ic//OEPf/gzDH8OHDjg+m996GGXw4cPo6OjA7W1tam/5efno7a2Fu3t7cfs39/fj2QymfpRXGSXiIho2JowYYLrPqE3Pj766CMMDAygtLTU8ffS0lIkEolj9m9qakI8Hk/9VFRUhF0kIiIiioiXIRNZn+2ydOlS9Pb2pn4OHDiQ7SIRERFRBhWEfcATTjgBY8aMQXd3t+Pv3d3dKCsrO2b/WCyGWCwWdjGIiIgoR4Xe81FYWIjq6mq0tbWl/jY4OIi2tjbU1NSEfToiIiIaZkLv+QCAJUuWoK6uDrNmzcLs2bOxYsUKfPrpp7j++usDH3vdunUhlBCorKx0vN67d28ox73yyitTv+/cuTMj5xhOenp6HK9LSkpSv+t1ZQpynfVra1PnjY2NjtcrVqzwXQb6f9J1BsJ7pv165ZVXHK8vvPDCSM+fqe+iKJjPzL59+9Lum+3rbCOsa6J//+nffcOd2zPtRUYaH1dffTX+85//4L777kMikcCMGTPw4osvHjMIlYiIiEafjDQ+AKChoQENDQ2ZOjwRERENU1mf7UJERESjS8Z6PjJl8eLFjtctLS1p9zVjufp7zRievs18n7mvNK5Dj2uasWP9uDNnzhyyzEMdx4yvxePxtO/r7e11vDbLp9M/l1mvEvMcUiw3ijinVD9uZZWuu35vmfVj3hN6jNi8tlJ59Osj3Wfmccz73uv5TdK97vas6dul59CNdO9JxzXj8tJYDf045vuk4/od/yFdH8B5jaTr43Zcr9fA7fnW70vpGTHPYY6LksYCSMeR7jVp/IXb2Axp/JfNddafE5t7QvoOcbtH0nEbf2Lzb1m695mCPN/psOeDiIiIIsXGBxEREUUqI6vaBpFMJsVuSJtucynsYRMSMUMX+nvNbnSv7zO7sczuSv1zSt3L0vkBZ5euWT82Xb9SN2hnZ2fa95lhF33qmTT1WrrOgPd6N69lc3Oz47U+VdCsH51bmEyil0fqsrUJ95n8huJsnhGT3uXu9zoD3sNkbqEDqWvYb92Zx9Q/lznNVPoOMc+hf07pnnQjdevrz+nrr78ulsfrOdxI9Txt2rTU7+YUfPO6S99b+n0pfaeZgoSrpXN6DWXYfFdLpLoySf/OuJ1Per7dptr29vaiuLhY3Ic9H0RERBQpNj6IiIgoUmx8EBERUaSG3ZgPG9I0UykmK8X+TWZcXI/HSed3K48Ul9fPYdaVFAO1KY9UPvOcemzQjFeb40FmzJiR+t0tBqp7+OGHHa9vv/32tPtKn9OsSylmrrOJCZu8jidyG1fidcquNI7DjRTblvb1O97CJI2NcLvX05XNPI7bcyA9M9I2qe78jj9zI43/snkO
pHPqz7D5fJtjAaT06l7HcZjbbb4bJdI53cZR2Izt80q6R8M6h3ROc+yeNNaIYz6IiIho2GPjg4iIiCLFxgcRERFFakSP+TD5HUdh8jpWwSZuaJMC22aeu99xAn7jjzapvW3OYZNnw+ZzeS2fW3m8jhmyuT4mm339Hsfm2up5G37xi184tukx4iBjPnQ2OR1sxjvYPJdel0hwG59iM95AOo7XMrjlppDyUejHMdOp244F8MNvniXAe10Gye0UxXMpsfmutHkuJG7l45gPIiIiyjlsfBAREVGkRlXYRRekq0qaduo1lGDT9Ssdx6270GsKdbfud72L0uxitwlZ6d22bmm3/ZK6L/3eW27H8dqFK3V3u51T4neqrc3UX5twgD49M0h6dZtr6XVabpBnxuu97nbtwgoBe52WaxN+dFspViJ9bmmlbpPXEESQ73Gb8HUYIWqb+85mKr/fcLpbaFCfemumTGDYhYiIiIYdNj6IiIgoUmx8EBERUaQKsl0AWzZptk1eU9SasTCbZdCl94UV0wuLFP+zmUbodUl04NipeukEuc5SfdmMf9C5XS/9tXSfmbHuTEyvcytrJu4nKW29G30Mkc11luL05meUxnHYjHsJa6yG1xTubmMRvN6zQZaMkFLB25DGjtiMndPZLEkg7ev2ufym3Pe6zdxuM94qyDl10jPc09Pj2BbGeD32fBAREVGk2PggIiKiSHGq7f/xO9VW2tdvtlHA+zSwsFb6NIU1vTjIdES/bFaK9drF7jZNTmfuu3HjxtTvCxYscGyzyaIqCStbrX7/mCEQ6V6Tumxtsl6GFVb1W3dAOCvFBskcGxYpw6nf7yZz5VPztddr7fYdkonMzJIgK5B7PU6QqbZ+s+mGNdXWDJlJqxcfPQ+n2hIREVFOYeODiIiIImXd+Hj11Vdx2WWXoby8HHl5eWhtbXVsV0rhvvvuw4knnohx48ahtrYWu3fvDqu8RERENMxZj/lYv349XnvtNVRXV+Oqq67CunXrHHHsBx54AE1NTXjiiSdQVVWFe++9F2+++SZ27dqFoqIi1+OHOebDb9zMJm4oxS6DrMboN0ZsUx79OM3NzY5t5lQqv7FuszxexwLYrCIbRUzYjdexPjarm/q9f9xiuX7HLJnjMfQpstJqp2Gl0s4GvzH8IOM4wjqOxG9qepM5FkC61vr94zaWJ+o6sJmK7Pe6m9+xM2bMcLz2+nwHWWHbZuyevt0s66ZNm9KW7+ix3MZ8WOf5mD9/PubPnz/kNqUUVqxYgZ/97Ge44oorAAB/+MMfUFpaitbWVlxzzTW2pyMiIqIRJtQxH3v27EEikUBtbW3qb/F4HHPmzEF7e/uQ7+nv70cymXT8EBER0cgVauMjkUgAAEpLSx1/Ly0tTW0zNTU1IR6Pp36mTp0aZpGIiIgox2Q9vfrSpUuxZMmS1OtkMhlaA8QmNieNjZCYYwr0OJmUptgtjul3fIo0bkA6ptsYDynviBQv1scFAPIS8pngNgbFZql1ide002Hlk5HitVJqcfO90hgUm9iyme/Baxp9kzSuxG1MjNexAUFSTuv1ZbNEgk1uHL9p/m1yi0ik8we5zvo4jyDp59PtFxWp7NJ1NuvO77gOt8/sNx2+9LnM6xxGvYfa81FWVgYA6O7udvy9u7s7tc0Ui8VQXFzs+CEiIqKRK9TGR1VVFcrKytDW1pb6WzKZxNatW1FTUxPmqYiIiGiYsp5q+8knn+Ddd98F8GXXy0MPPYS5c+di4sSJqKiowAMPPIDly5c7ptq+8cYbkUy19bsyIiBPzzT57day6WbzO61SOmcUKYTDqg830ufORIp5m9UzbVJFS/ymuLdZPVPidi31c5qrF+vhtbCmrttMDZTKagqrnvX7zgw3ZmMKcSbSzZtTLjs7Ox2vc22qdCbY/Pvk9ztfOk6QZRh0Ns+TzXT9o8cKfartjh07MHfu3NTro+M16urq0NLSgjvuuAOffvopbrzxRvT09OD888/H
iy++6KnhQURERCOfdePjwgsvhNRZkpeXh/vvvx/3339/oIIRERHRyMS1XYiIiChSWZ9qG5TXWPtQ23U20xr9ToOVYrB+p2BmKk27DSlG7pdbPFQaCyBN0zPXItLTgNvUR5AxF+n4jQGbgoxZslnWW99uxv5t+B0vM23aNMdr/drapOMPi1vKcF0U6cNteF2SwLzO+vNj7jucuP1b4fV62Xzf+p1SHWT8js2/HenOHxb2fBAREVGk2PggIiKiSLHxQURERJGyzvORaW55Psx8AvoSzkHyfITFbyzXJkeBzVLQYaWc9nrOsPJ8mGm2zXi633ECUnp1vzlAhnpvuuPY5HOxyWXid6yG39TrgDOXRWVlZdqySsusA/JS62GN1YhizIcuyJglv3lRwvpcNnlPTFIZ9O9ut3siCjZ5LPzmFbIZS2iz/IYuyLIMOmmci81349Htbnk+2PNBREREkWLjg4iIiCI17MIupkykrrYhdZ1Jenp6HK/NKWte06vbsOneDWu6qPm59PoxV9KVZCqk5jXMEFa3rE0IxOR13yDXWQpD7d271/FaD7WY+7a0tKR+t+lit+netu0KTsfvve63C9vmHEH4DT/67bYf6jw6r+E+t+OEJaywi9clLIKEdaWyhfle3QUXXJD63TaNPsMuRERElHPY+CAiIqJIsfFBREREkRp2Yz7cpmDq/MYRbd7ndxnvILFTryncAf/xa5tpnzpzjIc5tkU6h87mOgfhdQqxWwxYGiuh85u22WbfIGnapXvLnOa+ePHitMddsWJF6nebsT2mqKfImqTnKaz4vs35JTYp5aXn3Wabye+YD0mQ7+NcTvceZMyQxG/697DSJBzdzjEfRERElFPY+CAiIqJIDbuwi8mmW9ZmapVfXrvSbMIcQcqaidUzpW59t+mQmcjiZ5OR0e90bL9T70x+3+fG5jo3NzenfjdDIlL59IymgPPaSqvaBplSrbOZDuk3S6jNe4Nkq/X7HEhlyNR3mF5Wt9WLvV7rqJ5Lv9cyyJT4qNncW/rnMKfOm6/N513HsAsRERENO2x8EBERUaTY+CAiIqJIDbsxH1GkKQ5yDr+xQb/x4yhi2zbHdYt1e027bXMNbNIJS3UQZOxItqeE6oKUVR8PMmPGDPE8ev3o7wOAxsbGtOcwSava6oKs3un3OQlrbFim7uewSCuY2sj2qrZR/PuQqRQOmfheD2tqrYljPoiIiGjYYeODiIiIIsXGBxEREUVq2I35CCLqWKrf+ddu5YlifIHfsptxRDMvgD6OIIqxEZmKAftNQW0T5w2r7DapmaWcLdL4GT2dOuDMEbBv3z7P5Quy7HhY+V100nNg8juOQzqO2/uksWHS+YOkN5dkIr16WN/bmRrL5/Vzud2TXvOX2KRel54nm2UYTBzzQURERMOOVeOjqakJZ599NiZMmIDJkydjwYIF6Orqcuxz6NAh1NfXY9KkSRg/fjwWLlyI7u7uUAtNREREw1eBzc6bNm1CfX09zj77bHzxxRe4++678Z3vfAe7du3C8ccfD+DLqXLPP/881q5di3g8joaGBlx11VV47bXXMvIBbEjdmZmY5mQeR+oWznY6X5suOJuwgtt0zXSiWL3YjXRP2Kz6K5VHCmfZdOPbdNVL5/A7zd1Mq20zrTITIcYgK59KdRBWPUuhJpswlNcud5tVmaXjZIrNM+z3+1CqZ7djStu9lscm5b5NqEu6R8zj6K9trqueIgEIZ6q0VePjxRdfPKZAkydPRkdHB771rW+ht7cXjz/+OFavXo2LLroIwJfz/0899VRs2bIF55xzTuACExER0fAWaMzH0RbWxIkTAQAdHR04cuQIamtrU/tMnz4dFRUVaG9vH/IY/f39SCaTjh8iIiIauXw3PgYHB9HY2IjzzjsPZ5xxBgAgkUigsLAQJSUljn1LS0uRSCSGPE5TUxPi8XjqZ+rUqX6LRERERMOA76m2N998M9avX4/NmzdjypQpAIDVq1fj+uuv
R39/v2Pf2bNnY+7cuXjggQeOOU5/f79j/2QyKTZAzLEAOmlcgI2wlns22Sy5rZOmPLrFBv2m/jXPKcUHvU4zNY8bZLpf1Knh3c4npeT2WndBlvWOIq2/SX8WzeW39am3bsus+51qKx3HZiyLdN2laxnkHvSaztxm+rXfbYBcP/q1bW1tdWxbvHix47XbtfbD5tn3ez9n6ns0yDm9shkH5HfJCLdzDrXdbaqt1ZiPoxoaGvDcc8/h1VdfTTU8AKCsrAyHDx9GT0+Po/eju7sbZWVlQx4rFoshFov5KQYRERENQ1ZhF6UUGhoasG7dOmzYsAFVVVWO7dXV1Rg7diza2tpSf+vq6sL+/ftRU1MTTomJiIhoWLMKu9xyyy1YvXo1nnnmGZxyyimpv8fjcYwbNw7Al+GYF154AS0tLSguLsatt94KAPjHP/7h6RxBMpyGtWpgrq+c63fVzSgy/C1YsMCxraenx/Hab9glU8LKTuh3SmhY4RKpO9W8R/RrMnfuXMe2TIT/3OpDmpabiSnw2Q7hhXnOsMI3Oin0pV8r4NjnWwq7eL3OQ5U3E2zCdF7/vQgSNpRk4jvf7d8DfaVqfZVqIAthl1WrVgE4toKbm5tTsb+HH34Y+fn5WLhwIfr7+zFv3jw8+uijNqchIiKiEcyq8eGlk6SoqAgrV67EypUrfReKiIiIRi6u7UJERESRGvar2trEhL1OQQoy5VESxrgA85xBptBJcXAp3bJJms5mlk9f5dZmWp7fMQVBYslhXS+dTXzW7wqvbu8La4VMXaampNqU1et96PbMZmJpA5tVosM6Tlhp2nVuKbm9Xmub5RNsBJmWK/F7b4U1HsRmPJHXe8tmxWSbcSZHy8hVbYmIiCinsPFBREREkWLjg4iIiCI17MZ82OQ3MHmd6x8k7a3fdLphxRH1tNaAc1yFTd1lIy2wzi0m7Hc8Rq4J63N4zfcAyDFhqTxmjgcztbbOb3r1TI2xsBnfZPOdku44NkvYh5XfRRoD45byX99us0S7yesYiyB5NSQ2deA31bh0HL/XxxRkHJvX58lmvKC5Xtu+ffvSHvdomTjmg4iIiHIKGx9EREQUqWEXdrGRqRUgvZ4zU+myMzW11Gt5pHOa3XPmaqd6mt5MpJgGoknb7rd71+R35dywQjQ2U7P9skmvXllZ6dgmPZeSIKnypX29vi+IKK6zzfeN/gybz7Ntd3y6c2Q7dBoklBHWatN+QztSCE16vs1zmEtj6Nc6E+nV2fNBREREkWLjg4iIiCLFxgcRERFFymphuVxnxu2kGJvfbW68jtWwKWtYy7dLbNKr24yxMGOFXqcYmlNtzViz/rkzNcbDZnl5nc30VZvz+50q7vcecZueGdb09L1796Z+l5Zaj+o665/LLZV1um1BrrPfabhhsVlawZzab14/nf5MS1PngcyMB7N5Lv3+WyK9z+0zSt9pNs+B1/Ehbuneze/usLHng4iIiCLFxgcRERFFathPtbXp6pw2bVrq99bWVsc2m+6xMFYqDNLN6DdToE24xGYqqXTMKKbUhTXVNqxQhs1qq16n17mdI6xVmf1O/TXZZDjVmVlU9W58v6uQAv6n09qEIHQ20yrDWvU3U9Mz050P8D/VNsiqtplIZ2D73nRspq7brARtM+Xb6z3h9u+s16y36crEqbZERESUU9j4ICIiokix8UFERESRGnZjPqRYYTbSJoc1TsBmul1Yq0NKsUGTtAKjVNbm5mbHa7/p1W3GAWUi/bzNqpd+V5kMi1t82iam71VLS4vjtT591i2ebzMFUyeNN8iF74IojhPWvSXdM26rJOv8jl/J1HPp9b1SvQLhPDM2U9elc7gdx88xh9o3yP3MMR9ERESUc9j4ICIiokix8UFERESRGnZjPkw2KXK95rwIK45okuZNh7UMuzR2xCY26HfJePN9nZ2djtczZswY8n3ZEsZS2eb2TGwLs6z6tTTHTfT09KR+N+8Jc/yOfi39xp1tuNWP
17i4TXzf5jvE7/MVJN+E9LkkYaXu95ubxu0zZyKXh9/8O4D3MXg2+VPC+o71Wx63MSf6WCxpeYuhcMwHERER5RyrxseqVatw1llnobi4GMXFxaipqcH69etT2w8dOoT6+npMmjQJ48ePx8KFC9Hd3R16oYmIiGj4sgq7PPvssxgzZgxOPvlkKKXwxBNP4MEHH8TOnTtx+umn4+abb8bzzz+PlpYWxONxNDQ0ID8/H6+99prnArmFXcJKpW3TPWfyes6wpodKx7UJNQUpj/5eswtO6m6Oojs+GzIxrdJmqpvZLbtgwYLU73roxI3NlGG/03Dd6kPq3rU5jk5KA27WTyamZ9pMgQ/ruQzyPmmqfban2gZJZxDWqrZ+yyOVzeT3WvoN/7m9T58+v3jxYrEMQ53HLexSIG41XHbZZY7Xy5Ytw6pVq7BlyxZMmTIFjz/+OFavXo2LLroIwJcx4lNPPRVbtmzBOeecY3MqIiIiGqF8j/kYGBjAmjVr8Omnn6KmpgYdHR04cuQIamtrU/tMnz4dFRUVaG9vT3uc/v5+JJNJxw8RERGNXNaNjzfffBPjx49HLBbDTTfdhHXr1uG0005DIpFAYWHhMaPnS0tLkUgk0h6vqakJ8Xg89TN16lTrD0FERETDh1XYBQBOOeUUdHZ2ore3F3/5y19QV1eHTZs2+S7A0qVLsWTJktTrZDIpNkDMeO3GjRtTv8+dO9exzYyhtba2pj2OHnd2i9fqr/VjDlWGdMzYts2S0plI1WzuJ417sVk22pxqq8eWpaXWzTT60lQv/R4AnOMfbOKsUtzX5jr7jeHbpDuWxh1J9WGyKatZB9Jx9eteVVXl+RxBSNdSquewxmbp3ylux9SvkfmM2HwXeC2b272l30/m/SPRp1sDEP8t8Dq2x2QzLsmkj2MI6143t+n/4Za+G/2OmXJjk/pd+g4x99XHedhcZ6+sGx+FhYU46aSTAADV1dXYvn07HnnkEVx99dU4fPgwenp6HBeju7sbZWVlaY8Xi8UQi8XsS05ERETDUuA8H4ODg+jv70d1dTXGjh2Ltra21Lauri7s378fNTU1QU9DREREI4RVz8fSpUsxf/58VFRUoK+vD6tXr8Yrr7yCl156CfF4HDfccAOWLFmCiRMnori4GLfeeitqamo404WIiIhSrBofBw8exHXXXYcPP/wQ8XgcZ511Fl566SV8+9vfBvBlTC8/Px8LFy5Ef38/5s2bh0cffTTUApvxUT126BZT098rxRxtYnNm3Nvr8uDm+cPKAyCNJbGJdUtzxW3KY+YPMAck+6WXwbwG0rX1e93Nc/iN35rntymPdE/oxzHLKsXFpbE15jbz2ZNIeSMkQerH6zabnA4mfTyEWc/6a7cxDWGNw/H6PrfPLN0/+jbzczU2Njpe+x0LYDNeRR9zIdWjDfMc0nGl+jGfGen7VyJdr7DGrrgdZ8WKFanfzTGSkY/5ePzxx8XtRUVFWLlyJVauXBmoUERERDRycW0XIiIiipT1bJdco3crSSEZwDldSO9SApzdh9L7ALlLTgoD2Uwv05mhCn2KqtnNZ5ZVCvXor91WMK2srEz9btP9bpPqWyKFk6SpgW4hCLP+vJ5fIt1b0nHdzqFfE7NepWvpNm3Zb3mkffXUzDbM40jPt3mv61MDpfqRzgE47xnzWnrt5repuyCkbn0pPCF9F9iUze/z7RZe01+b339ev3+Heq2T7i2Jed/t3bs39buUCsImRGTzb4e5Ta8v6XNJ7zOFFS7XseeDiIiIIsXGBxEREUWKjQ8iIiKKVJ5SSmW7ELpkMpmxNLRERESUWb29vSguLhb3Yc8HERERRYqNDyIiIooUGx9EREQUKTY+iIiIKFJsfBAREVGkcq7xkWOTb4iIiMiCl3/Hc67x0dfXl+0iEBERkU9e/h3PuTwfg4OD+OCDD6CUQkVFBQ4cOOA6X3g0SiaTmDp1KusnDdaPjPUjY/3IWD/pjea6UUqhr68P5eXl
yM+X+zZybmG5/Px8TJkyBclkEgBQXFw86i6gDdaPjPUjY/3IWD8y1k96o7VuvCYJzbmwCxEREY1sbHwQERFRpHK28RGLxfDzn/8csVgs20XJSawfGetHxvqRsX5krJ/0WDfe5NyAUyIiIhrZcrbng4iIiEYmNj6IiIgoUmx8EBERUaTY+CAiIqJIsfFBREREkcrZxsfKlStRWVmJoqIizJkzB9u2bct2kSLX1NSEs88+GxMmTMDkyZOxYMECdHV1OfY5dOgQ6uvrMWnSJIwfPx4LFy5Ed3d3lkqcXcuXL0deXh4aGxtTfxvt9fP+++/jBz/4ASZNmoRx48bhzDPPxI4dO1LblVK47777cOKJJ2LcuHGora3F7t27s1ji6AwMDODee+9FVVUVxo0bh69//ev45S9/6VgUazTVz6uvvorLLrsM5eXlyMvLQ2trq2O7l7r4+OOPsWjRIhQXF6OkpAQ33HADPvnkkwg/ReZI9XPkyBHceeedOPPMM3H88cejvLwc1113HT744APHMUZy/VhTOWjNmjWqsLBQ/f73v1f/+te/1I9//GNVUlKiuru7s120SM2bN081Nzert956S3V2dqrvfve7qqKiQn3yySepfW666SY1depU1dbWpnbs2KHOOeccde6552ax1Nmxbds2VVlZqc466yx12223pf4+muvn448/VtOmTVOLFy9WW7duVe+995566aWX1LvvvpvaZ/ny5Soej6vW1lb1+uuvq8svv1xVVVWpzz//PIslj8ayZcvUpEmT1HPPPaf27Nmj1q5dq8aPH68eeeSR1D6jqX5eeOEFdc8996inn35aAVDr1q1zbPdSF5dccon6xje+obZs2aL+/ve/q5NOOklde+21EX+SzJDqp6enR9XW1qqnnnpKvf3226q9vV3Nnj1bVVdXO44xkuvHVk42PmbPnq3q6+tTrwcGBlR5eblqamrKYqmy7+DBgwqA2rRpk1Lqyxt+7Nixau3atal9/v3vfysAqr29PVvFjFxfX586+eST1csvv6wuuOCCVONjtNfPnXfeqc4///y02wcHB1VZWZl68MEHU3/r6elRsVhM/elPf4qiiFl16aWXqh/96EeOv1111VVq0aJFSqnRXT/mP65e6mLXrl0KgNq+fXtqn/Xr16u8vDz1/vvvR1b2KAzVODNt27ZNAVD79u1TSo2u+vEi58Iuhw8fRkdHB2pra1N/y8/PR21tLdrb27NYsuzr7e0FAEycOBEA0NHRgSNHjjjqavr06aioqBhVdVVfX49LL73UUQ8A6+evf/0rZs2ahe9973uYPHkyZs6cid/97nep7Xv27EEikXDUTzwex5w5c0ZF/Zx77rloa2vDO++8AwB4/fXXsXnzZsyfPx8A60fnpS7a29tRUlKCWbNmpfapra1Ffn4+tm7dGnmZs623txd5eXkoKSkBwPox5dyqth999BEGBgZQWlrq+HtpaSnefvvtLJUq+wYHB9HY2IjzzjsPZ5xxBgAgkUigsLAwdXMfVVpaikQikYVSRm/NmjX45z//ie3btx+zbbTXz3vvvYdVq1ZhyZIluPvuu7F9+3b85Cc/QWFhIerq6lJ1MNSzNhrq56677kIymcT06dMxZswYDAwMYNmyZVi0aBEAjPr60Xmpi0QigcmTJzu2FxQUYOLEiaOuvg4dOoQ777wT1157bWplW9aPU841Pmho9fX1eOutt7B58+ZsFyVnHDhwALfddhtefvllFBUVZbs4OWdwcBCzZs3Cr371KwDAzJkz8dZbb+Gxxx5DXV1dlkuXfX/+85/x5JNPYvXq1Tj99NPR2dmJxsZGlJeXs37ItyNHjuD73/8+lFJYtWpVtouTs3Iu7HLCCSdgzJgxx8xI6O7uRllZWZZKlV0NDQ147rnnsHHjRkyZMiX197KyMhw+fBg9PT2O/UdLXXV0dODgwYP45je/iYKCAhQUFGDTpk349a9/jYKCApSWlo7q+jnxxBNx2mmnOf52
6qmnYv/+/QCQqoPR+qz99Kc/xV133YVrrrkGZ555Jn74wx/i9ttvR1NTEwDWj85LXZSVleHgwYOO7V988QU+/vjjUVNfRxse+/btw8svv5zq9QBYP6aca3wUFhaiuroabW1tqb8NDg6ira0NNTU1WSxZ9JRSaGhowLp167BhwwZUVVU5tldXV2Ps2LGOuurq6sL+/ftHRV1dfPHFePPNN9HZ2Zn6mTVrFhYtWpT6fTTXz3nnnXfM1Ox33nkH06ZNAwBUVVWhrKzMUT/JZBJbt24dFfXz2WefIT/f+RU4ZswYDA4OAmD96LzURU1NDXp6etDR0ZHaZ8OGDRgcHMScOXMiL3PUjjY8du/ejb/97W+YNGmSY/tor59jZHvE61DWrFmjYrGYamlpUbt27VI33nijKikpUYlEIttFi9TNN9+s4vG4euWVV9SHH36Y+vnss89S+9x0002qoqJCbdiwQe3YsUPV1NSompqaLJY6u/TZLkqN7vrZtm2bKigoUMuWLVO7d+9WTz75pDruuOPUH//4x9Q+y5cvVyUlJeqZZ55Rb7zxhrriiitG7FRSU11dnfrqV7+ammr79NNPqxNOOEHdcccdqX1GU/309fWpnTt3qp07dyoA6qGHHlI7d+5MzdbwUheXXHKJmjlzptq6davavHmzOvnkk0fMVFKpfg4fPqwuv/xyNWXKFNXZ2en4vu7v708dYyTXj62cbHwopdRvfvMbVVFRoQoLC9Xs2bPVli1bsl2kyAEY8qe5uTm1z+eff65uueUW9ZWvfEUdd9xx6sorr1Qffvhh9gqdZWbjY7TXz7PPPqvOOOMMFYvF1PTp09Vvf/tbx/bBwUF17733qtLSUhWLxdTFF1+surq6slTaaCWTSXXbbbepiooKVVRUpL72ta+pe+65x/GPxWiqn40bNw75fVNXV6eU8lYX//3vf9W1116rxo8fr4qLi9X111+v+vr6svBpwifVz549e9J+X2/cuDF1jJFcP7bylNLS+RERERFlWM6N+SAiIqKRjY0PIiIiihQbH0RERBQpNj6IiIgoUmx8EBERUaTY+CAiIqJIsfFBREREkWLjg4iIiCLFxgcRERFFio0PIiIiihQbH0RERBSp/wFoQYQXP1ktRwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0.0026, norm: 0.0014, param_norm: 935.4885, vb_loss: 0.0008, ce_loss: 0.7653:  19%|█▉        | 298/1563 [00:23<00:50, 25.17it/s]"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVCUlEQVR4nO3df2xV9f3H8VdL29tqube0rPfSwZVuklUHONZCuWL2y7sAMwqj24SwUZXMwIqjNpmIDvbHxkq2ZKKLYrZlsGUyHIngZFPCCsJYaks76qyVirGRRrgXHfbegtJi72d/fOP9cvlRekv7ube9z0fySeg5h3vffX9o74vPOefeNGOMEQAAgCXpiS4AAACkFsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsGrYwseTTz6pyZMnKzs7W+Xl5WpsbByupwIAACNI2nB8tsuzzz6rZcuW6emnn1Z5ebk2bdqkHTt2qL29XYWFhf3+3UgkohMnTmjs2LFKS0sb6tIAAMAwMMaou7tbRUVFSk+/ytqGGQazZs0yVVVV0a/7+vpMUVGRqa2tverf7ezsNJIYDAaDwWCMwNHZ2XnV1/ohP+3S29ur5uZm+f3+6Lb09HT5/X7V19dfcnxPT4/C4XB0GD5kFwCAEWvs2LFXPWbIw8f777+vvr4+ud3umO1ut1uBQOCS42tra+VyuaLD6/UOdUkAAMCSgVwykfC7XdauXatQKBQdnZ2diS4JAAAMo4yhfsDx48drzJgxCgaDMduDwaA8Hs8lxzscDjkcjqEuAwAAJKkhX/nIyspSaWmp6urqotsikYjq6urk8/mG+ukAAMAIM+QrH5JUU1OjyspKlZWVadasWdq0aZPOnj2re++9dzieDgAAjCDDEj7uvvtuvffee1q/fr0CgYC+8IUv6KWXXrrkIlQAAJB6huVNxq5FOByWy+VKdBkAAGAQQqGQnE5nv8ck/G4XAACQWggfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwKqMRBeA+H3wwQfD/hzjxo0b9udAfIZj3pnnxLPx83wx5t0+5jkWKx8AAMAqwgcAALAq7vBx8OBB3XnnnSoqKlJaWpp27doVs98Yo/Xr12vChAnKycmR3+/XsWPHhqpeAAAwwsV9zcfZs2d1yy236L777tOiRYsu2f+LX/xCTzzxhP7whz+ouLhY69at09y5c9XW1qbs7OwhKXo0SsT5QAyPkTSXF9eazOeIk81Immfm9dow10Mv7vAxf/58zZ8//7L7jDHatGmTfvzjH2vBggWSpD/+8Y9yu93atWuXFi9efG3VAgCAEW9Ir/no6OhQIBCQ3++PbnO5XCovL1d9ff1l/05PT4/C4XDMAAAAo9eQho9AICBJcrvdMdvdbnd038Vqa2vlcrmiY9KkSUNZEgAASDIJf5+PtWvXqqamJvp1OBwe0QEk2c4NjpTzfyNNoueZebWDeU4NiZ5nKfXmekhXPjwejyQpGAzGbA8Gg9F9F3M4HHI6nTEDAACMXkMaPoqLi+XxeFRXVxfdFg6H1dDQIJ/PN5RPBQAARqi4T7ucOXNGb731VvTrjo4OtbS0KD8/X16vV9XV1frZz36mKVOmRG+1LSoq0sKFC4eyblxBqi3d2ZIMy7IAMFrEHT6ampr01a9+Nfr1J9drVFZWauvWrXrooYd09uxZ3X///erq6tJtt92ml156iff4AAAAkqQ0Y4xJdBEXCofDcrlciS5j0BL9P2RWPoZHouf1YsyzHYmed+bZjkTPszS65joUCl31+k0+2wUAAFjFygcAABgyrHwAAICkQ/gAAABWET4A
AIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFiVkegCAGC0++CDD664b9y4cRYrAZIDKx8AAMAqwgcAALCK8AEAAKzimg8AGGZc1wHEYuUDAABYFVf4qK2t1cyZMzV27FgVFhZq4cKFam9vjznm3LlzqqqqUkFBgXJzc1VRUaFgMDikRQMAgJErzRhjBnrwvHnztHjxYs2cOVMff/yxHnnkEbW2tqqtrU3XX3+9JGnlypX629/+pq1bt8rlcmnVqlVKT0/Xv/71rwE9RzgclsvlGtx3k4L6u4XvYiz9jl4X/jtgnkc25hIjXSgUktPp7PeYuMLHxd577z0VFhbqwIED+tKXvqRQKKRPfepT2rZtm771rW9Jko4ePaqbbrpJ9fX1mj179lUfk/ARH8IHJF6wRhPmEiPdQMLHNV3zEQqFJEn5+fmSpObmZp0/f15+vz96TElJibxer+rr6y/7GD09PQqHwzEDAACMXoMOH5FIRNXV1ZozZ46mTp0qSQoEAsrKylJeXl7MsW63W4FA4LKPU1tbK5fLFR2TJk0abEkAAGAEGPSttlVVVWptbdWhQ4euqYC1a9eqpqYm+nU4HCaAAJcRzyk2jA4XzzmnYUaPVH/L/UGFj1WrVmn37t06ePCgJk6cGN3u8XjU29urrq6umNWPYDAoj8dz2cdyOBxyOByDKQMAAIxAcZ12McZo1apV2rlzp/bt26fi4uKY/aWlpcrMzFRdXV10W3t7u44fPy6fzzc0FQMAgBEtrpWPqqoqbdu2Tc8//7zGjh0bvY7D5XIpJydHLpdLy5cvV01NjfLz8+V0OvXAAw/I5/MN6E4XAP+PO5mA0YPTprHiCh+bN2+WJH3lK1+J2b5lyxbdc889kqTHHntM6enpqqioUE9Pj+bOnaunnnpqSIoFAAAj3zW9z8dw4H0+4sP/jkcv5jY1pfqFiKNVKv08D/v7fAAAAMSLT7VNUkN1fnCkJ+jRhvO+qYl5H72Y28Fh5QMAAFhF+AAAAFYRPgAAgFVc85Egw3WekGs8ks9wzDXznHyY59HLxnUdqTbXrHwAAACrCB8AAMAqTrsMI06tpA6W3FMD85wamOfhx8oHAACwivABAACsInwAAACruObjGnFdR2pgnlMH5/tTA/OcWKx8AAAAqwgfAADAKk67DMJQLNexPJd8OLWSGpjn1MGpleTFygcAALCK8AEAAKwifAAAAKu45mMYcW4w+Q3VOWHmOrkxz6mBeR45WPkAAABWET4AAIBVhA8AAGBVmjHGJLqIC4XDYblcrkSXAQAABiEUCsnpdPZ7DCsfAADAqrjCx+bNmzV9+nQ5nU45nU75fD69+OKL0f3nzp1TVVWVCgoKlJubq4qKCgWDwSEvGgAAjFxxhY+JEydq48aNam5uVlNTk772ta9pwYIFev311yVJDz74oF544QXt2LFDBw4c0IkTJ7Ro0aJhKRwAAIxQ5hqNGzfO/O53vzNdXV0mMzPT7NixI7rvjTfeMJJMfX39gB8vFAoZSQwGg8FgMEbgCIVCV32tH/Q1H319fdq+fbvOnj0rn8+n5uZmnT9/Xn6/P3pMSUmJvF6v6uvrr/g4PT09CofDMQMAAIxecYeP1157Tbm5uXI4HFqxYoV27typm2++WYFAQFlZWcrLy4s53u12KxAIXPHxamtr5XK5omPSpElxfxMAAGDkiDt8fO5zn1NLS4saGhq0cuVKVVZWqq2tbdAFrF27VqFQKDo6OzsH/VgAACD5xf3ZLllZWbrxxhslSaWlpTp8+LAef/xx3X333ert7VVXV1fM6kcwGJTH47ni4zkcDjkcjvgrBwAAI9I1v89HJBJRT0+PSktLlZmZqbq6uui+9vZ2HT9+XD6f71qfBgAAjBJxrXysXbtW8+fPl9frVXd3t7Zt26aXX35Ze/bskcvl
0vLly1VTU6P8/Hw5nU498MAD8vl8mj179nDVDwAARpi4wsepU6e0bNkynTx5Ui6XS9OnT9eePXv09a9/XZL02GOPKT09XRUVFerp6dHcuXP11FNPDUvhAABgZOKzXQAAwJDhs10AAEDSIXwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAqqQLH0n2UTMAACAOA3kdT7rw0d3dnegSAADAIA3kdTzpPtU2EonoxIkTMsbI6/Wqs7Pzqp+Ol4rC4bAmTZpEf66A/vSP/vSP/vSP/lxZKvfGGKPu7m4VFRUpPb3/tY0MSzUNWHp6uiZOnKhwOCxJcjqdKTeB8aA//aM//aM//aM//aM/V5aqvXG5XAM6LulOuwAAgNGN8AEAAKxK2vDhcDj0k5/8RA6HI9GlJCX60z/60z/60z/60z/6c2X0ZmCS7oJTAAAwuiXtygcAABidCB8AAMAqwgcAALCK8AEAAKwifAAAAKuSNnw8+eSTmjx5srKzs1VeXq7GxsZEl2RdbW2tZs6cqbFjx6qwsFALFy5Ue3t7zDHnzp1TVVWVCgoKlJubq4qKCgWDwQRVnFgbN25UWlqaqquro9tSvT/vvvuuvvvd76qgoEA5OTmaNm2ampqaovuNMVq/fr0mTJignJwc+f1+HTt2LIEV29PX16d169apuLhYOTk5+uxnP6uf/vSnMR+KlUr9OXjwoO68804VFRUpLS1Nu3btitk/kF6cPn1aS5culdPpVF5enpYvX64zZ85Y/C6GT3/9OX/+vNasWaNp06bp+uuvV1FRkZYtW6YTJ07EPMZo7k/cTBLavn27ycrKMr///e/N66+/br7//e+bvLw8EwwGE12aVXPnzjVbtmwxra2tpqWlxXzjG98wXq/XnDlzJnrMihUrzKRJk0xdXZ1pamoys2fPNrfeemsCq06MxsZGM3nyZDN9+nSzevXq6PZU7s/p06fNDTfcYO655x7T0NBg3n77bbNnzx7z1ltvRY/ZuHGjcblcZteuXebVV181d911lykuLjYfffRRAiu3Y8OGDaagoMDs3r3bdHR0mB07dpjc3Fzz+OOPR49Jpf78/e9/N48++qh57rnnjCSzc+fOmP0D6cW8efPMLbfcYl555RXzz3/+09x4441myZIllr+T4dFff7q6uozf7zfPPvusOXr0qKmvrzezZs0ypaWlMY8xmvsTr6QMH7NmzTJVVVXRr/v6+kxRUZGpra1NYFWJd+rUKSPJHDhwwBjzf//gMzMzzY4dO6LHvPHGG0aSqa+vT1SZ1nV3d5spU6aYvXv3mi9/+cvR8JHq/VmzZo257bbbrrg/EokYj8djfvnLX0a3dXV1GYfDYf785z/bKDGh7rjjDnPffffFbFu0aJFZunSpMSa1+3Pxi+tAetHW1mYkmcOHD0ePefHFF01aWpp59913rdVuw+XC2cUaGxuNJPPOO+8YY1KrPwORdKddent71dzcLL/fH92Wnp4uv9+v+vr6BFaWeKFQSJKUn58vSWpubtb58+djelVSUiKv15tSvaqqqtIdd9wR0weJ/vz1r39VWVmZvv3tb6uwsFAzZszQb3/72+j+jo4OBQKBmP64XC6Vl5enRH9uvfVW1dXV6c0335Qkvfrqqzp06JDmz58vif5caCC9qK+vV15ensrKyqLH+P1+paenq6GhwXrNiRYKhZSWlqa8vDxJ9OdiSfeptu+//776+vrkdrtjtrvdbh09ejRBVSVeJBJRdXW15syZo6lTp0qSAoGAsrKyov+4P+F2uxUIBBJQpX3bt2/Xv//9bx0+fPiSfanen7ffflubN29WTU2NHnnkER0+fFg//OEPlZWVpcrKymgPLvezlgr9efjhhxUOh1VSUqIxY8aor69PGzZs0NKlSyUp5ftzoYH0IhAIqLCwMGZ/RkaG8vPzU65f586d05o1a7RkyZLoJ9vSn1hJFz5weVVV
VWptbdWhQ4cSXUrS6Ozs1OrVq7V3715lZ2cnupykE4lEVFZWpp///OeSpBkzZqi1tVVPP/20KisrE1xd4v3lL3/RM888o23btunzn/+8WlpaVF1draKiIvqDQTt//ry+853vyBijzZs3J7qcpJV0p13Gjx+vMWPGXHJHQjAYlMfjSVBVibVq1Srt3r1b+/fv18SJE6PbPR6Pent71dXVFXN8qvSqublZp06d0he/+EVlZGQoIyNDBw4c0BNPPKGMjAy53e6U7s+ECRN08803x2y76aabdPz4cUmK9iBVf9Z+9KMf6eGHH9bixYs1bdo0fe9739ODDz6o2tpaSfTnQgPphcfj0alTp2L2f/zxxzp9+nTK9OuT4PHOO+9o79690VUPif5cLOnCR1ZWlkpLS1VXVxfdFolEVFdXJ5/Pl8DK7DPGaNWqVdq5c6f27dun4uLimP2lpaXKzMyM6VV7e7uOHz+eEr26/fbb9dprr6mlpSU6ysrKtHTp0uifU7k/c+bMueTW7DfffFM33HCDJKm4uFgejyemP+FwWA0NDSnRnw8//FDp6bG/AseMGaNIJCKJ/lxoIL3w+Xzq6upSc3Nz9Jh9+/YpEomovLzces22fRI8jh07pn/84x8qKCiI2Z/q/blEoq94vZzt27cbh8Nhtm7datra2sz9999v8vLyTCAQSHRpVq1cudK4XC7z8ssvm5MnT0bHhx9+GD1mxYoVxuv1mn379pmmpibj8/mMz+dLYNWJdeHdLsakdn8aGxtNRkaG2bBhgzl27Jh55plnzHXXXWf+9Kc/RY/ZuHGjycvLM88//7z5z3/+YxYsWDBqbyW9WGVlpfn0pz8dvdX2ueeeM+PHjzcPPfRQ9JhU6k93d7c5cuSIOXLkiJFkfvWrX5kjR45E79YYSC/mzZtnZsyYYRoaGsyhQ4fMlClTRs2tpP31p7e319x1111m4sSJpqWlJeb3dU9PT/QxRnN/4pWU4cMYY379618br9drsrKyzKxZs8wrr7yS6JKsk3TZsWXLlugxH330kfnBD35gxo0bZ6677jrzzW9+05w8eTJxRSfYxeEj1fvzwgsvmKlTpxqHw2FKSkrMb37zm5j9kUjErFu3zrjdbuNwOMztt99u2tvbE1StXeFw2Kxevdp4vV6TnZ1tPvOZz5hHH3005sUilfqzf//+y/6+qaysNMYMrBf//e9/zZIlS0xubq5xOp3m3nvvNd3d3Qn4boZef/3p6Oi44u/r/fv3Rx9jNPcnXmnGXPB2fgAAAMMs6a75AAAAoxvhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFb9D9oAs+8UV1pFAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0.0011, norm: 0.0024, param_norm: 936.5659, vb_loss: 0.0005, ce_loss: 0.4201:  38%|███▊      | 598/1563 [00:44<00:38, 24.82it/s]"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWY0lEQVR4nO3df2yV5f3/8VdL6Slazikt6Tl0cKSbZOiAjRUoR8x0ehZgRmGQTQgb1ZEZWHHUJhPRwRI3VpYlE7cgZr9gy2Q4EsHJBoQVhGFKgY46EK0YmTTCOcyxnlNQWuy5Pn98v97pKVDOKe19nx/PR3Il7X1fPefd99X79J3run/kGGOMAAAAbJLrdAAAACC7UHwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbDVjxsW7dOo0ePVoFBQWqrKzUoUOHBuqtAABAGskZiGe7vPjii1q4cKGef/55VVZWau3atdqyZYtaWlpUWlra68/GYjGdOXNGQ4cOVU5OTn+HBgAABoAxRu3t7SorK1Nu7nXmNswAmDJliqmurra+7+rqMmVlZaauru66P9va2mok0Wg0Go1GS8PW2tp63f/1/b7s0tnZqaamJgWDQWtbbm6ugsGgGhoarujf0dGhaDRqNcNDdgEASFtDhw69bp9+Lz4++OADdXV1yev1xm33er0KhUJX9K+rq5PH47Ga3+/v75AAAIBNEjllwvGrXVasWKFIJGK11tZWp0MCAAADKK+/X3D48OEaNGiQwuFw3PZwOCyfz3dFf5fLJZfL1d9hAACAFNXvMx/5+fmqqKhQfX29tS0Wi6m+vl6BQKC/3w4AAKSZfp/5kKTa2lpVVVVp0qRJmjJlitauXauLFy/q4YcfHoi3AwAAaWRAio8HH3xQ//nPf7Rq1SqFQiF94Qtf0M6dO684CRUAAGSfAbnJ2I2IRqPyeDxOhwEAAPogEonI7Xb32sfxq10AAEB2ofgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ynM6ACTvf//7X8J9a2pqrK/Xrl2b8M8NGzYsiYgwEOwY554Yd/v1dZylvo8142w/u8Y5XcaWmQ8AAGArig8AAGCrpIuP/fv36/7771dZWZlycnK0bdu2uP3GGK1atUojRozQkCFDFAwGdfLkyf6KFwAApLkcY4xJ5gd27Nih1157TRUVFZozZ462bt2q2bNnW/t/+tOfqq6uTr///e9VXl6ulStX6tixYzpx4oQKCgqu+/rRaFQejyfpXyRVdF/X67n2lsyaX6pJl3VEu1RVVV1z342cc+E0xjlxHM+Zo+dY9te5NU5IhbGNRCJyu9299kn6hNOZM2dq5syZV91njNHatWv1gx/8QLNmzZIk/eEPf5DX69W2bds0b968ZN8OAABkmH495+PUqVMKhUIKBoPWNo/Ho8rKSjU0NFz1Zzo6OhSNRuMaAADIXP1afIRCIUmS1+uN2+71eq19PdXV1cnj8Vht1KhR/RkSAABIMY7f52PFihWqra21vo9GoylfgCS61uvEmnAqrPdlilRe02ecB07383mcXutnnAdOqp23lehYp/LnUjL6debD5/NJksLhcNz2cDhs7evJ5XLJ7XbHNQAAkLn6tfgoLy+Xz+dTfX29tS0ajaqxsVGBQKA/3woAAKSppJddLly4oHfeecf6/tSpU2publZxcbH8fr9qamr04x//WGPGjLEutS0rK4u7HBfX130KLlOm2XAlxjn1DMSUeyZddp8p0nWcM2UpLuni48iRI/ryl79sff/J+RpVVVXauHGjHn/8cV28eFGPPPKI2tradOed
d2rnzp0J3eMDAABkvqSLj7vvvlu93ZcsJydHTz/9tJ5++ukbCgwAAGQmnu0CAABslfTt1QdaJt1ePRmZso6XLRjnzNXb2CZ6jg7jnN6SOb4Z6yslcnt1Zj4AAICtKD4AAICtKD4AAICtHL+9eqbpbf2Pa/szB+OcuRJdw+feHZmL43vgMfMBAABsRfEBAABsxbJLiug5lcflW5mJcc4OjHP26D7WjHPimPkAAAC2ovgAAAC2ovgAAAC24pwPG/H49Ozg9CWYnG/gDI7v7OD08Z0pmPkAAAC2ovgAAAC2YtklRfU2lTd79mzr63379tkQDQbKQDw9k2WW5PTXMlVvY8lTUrMD45w4Zj4AAICtKD4AAICtKD4AAICtcowxxukguotGo/J4PE6HYbuBulwr29cVU01/jTPj2n/6a53ejksuGffUMxDjnu7jHIlE5Ha7e+3DzAcAALAVxQcAALAVxQcAALAV53ykgRtZU+SeIOmDewSkhkTHYaDGoLdHtHPr/IExUHnt6zGd7uPMOR8AACDlJFV81NXVafLkyRo6dKhKS0s1e/ZstbS0xPW5dOmSqqurVVJSosLCQs2dO1fhcLhfgwYAAOkrqWWXGTNmaN68eZo8ebI+/vhjPfnkkzp+/LhOnDihm2++WZK0ZMkS/fWvf9XGjRvl8Xi0dOlS5ebm6rXXXkvoPVh2SQ5T9dmjt7FmbAeG0zm/3vHNuPcPJ/KcyZ/diSy7JPVsl507d8Z9v3HjRpWWlqqpqUlf+tKXFIlE9Nvf/labNm3SPffcI0nasGGDbrvtNh08eFBTp05N8lcAAACZ5obO+YhEIpKk4uJiSVJTU5MuX76sYDBo9Rk7dqz8fr8aGhqu+hodHR2KRqNxDQAAZK4+Fx+xWEw1NTWaNm2axo0bJ0kKhULKz89XUVFRXF+v16tQKHTV16mrq5PH47HaqFGj+hoSAABIA0ktu3RXXV2t48eP68CBAzcUwIoVK1RbW2t9H41GKUCAJPV2eSb6zolc2nGbdsTjmLFfn4qPpUuXavv27dq/f79Gjhxpbff5fOrs7FRbW1vc7Ec4HJbP57vqa7lcLrlcrr6EAQAA0lBSyy7GGC1dulRbt27Vnj17VF5eHre/oqJCgwcPVn19vbWtpaVFp0+fViAQ6J+IAQBAWktq5qO6ulqbNm3Syy+/rKFDh1rncXg8Hg0ZMkQej0eLFi1SbW2tiouL5Xa79eijjyoQCHClCxzBpYpwwvXuUMnSSv/g+E5fSRUf69evlyTdfffdcds3bNighx56SJL0zDPPKDc3V3PnzlVHR4emT5+u5557rl+CBQAA6S+p4iOR+5EVFBRo3bp1WrduXZ+DAgAAmYtnuwAAAFv1+VLbdJfqTw286667rK+3bdvmXCBpLhXW2quqqqyv//3vf1+zH+Oc3nr72+p+PN+IVPuccloqXArdPYae48wxfW3MfAAAAFtRfAAAAFtRfAAAAFvlmEQuYbFRNBqVx+NxOowB0X3tv6e1a9f2y3uwJuw8xjk79Bzn7ufzDNRaP+PuDI7p5EQiEbnd7l77MPMBAABsRfEBAABslbWX2tqh52VXfZ2ey6TpuEw1EJfwMu6pp/v0O9PtmavnMkt/jDXjHI+ZDwAAYCuKDwAAYCuKDwAAYCvO+egDLrvKDoxzdmCcswPjnFqY+QAAALai+AAAALZi2eUq7HjyKdNzqYGxzkxOPL2YcXYex3P6YOYDAADYiuIDAADYiuIDAADYinM+/j/WCrODE+cCABgYHM/pi5kPAABgK4oPAABgK4oPAABgqxxjjHE6iO6i0ag8Ho/TYQAAgD6IRCJyu9299mHmAwAA2Cqp4mP9+vWaMGGC3G633G63AoGAduzYYe2/dOmSqqurVVJSosLCQs2dO1fhcLjfgwYAAOkrqeJj5MiRWrNmjZqamnTkyBHdc889mjVrlt544w1J0mOPPaZXXnlFW7Zs0b59+3TmzBnNmTNnQAIHAABpytyg
YcOGmd/85jemra3NDB482GzZssXa9+abbxpJpqGhIeHXi0QiRhKNRqPRaLQ0bJFI5Lr/6/t8zkdXV5c2b96sixcvKhAIqKmpSZcvX1YwGLT6jB07Vn6/Xw0NDdd8nY6ODkWj0bgGAAAyV9LFx7Fjx1RYWCiXy6XFixdr69atuv322xUKhZSfn6+ioqK4/l6vV6FQ6JqvV1dXJ4/HY7VRo0Yl/UsAAID0kXTx8dnPflbNzc1qbGzUkiVLVFVVpRMnTvQ5gBUrVigSiVittbW1z68FAABSX9LPdsnPz9ett94qSaqoqNDhw4f17LPP6sEHH1RnZ6fa2triZj/C4bB8Pt81X8/lcsnlciUfOQAASEs3fJ+PWCymjo4OVVRUaPDgwaqvr7f2tbS06PTp0woEAjf6NgAAIEMkNfOxYsUKzZw5U36/X+3t7dq0aZNeffVV7dq1Sx6PR4sWLVJtba2Ki4vldrv16KOPKhAIaOrUqQMVPwAASDNJFR/nzp3TwoULdfbsWXk8Hk2YMEG7du3SV77yFUnSM888o9zcXM2dO1cdHR2aPn26nnvuuQEJHAAApCee7QIAAPoNz3YBAAAph+IDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYKuWKjxR71AwAAEhCIv/HU674aG9vdzoEAADQR4n8H0+5p9rGYjGdOXNGxhj5/X61trZe9+l42SgajWrUqFHk5xrIT+/IT+/IT+/Iz7Vlc26MMWpvb1dZWZlyc3uf28izKaaE5ebmauTIkYpGo5Ikt9uddQOYDPLTO/LTO/LTO/LTO/JzbdmaG4/Hk1C/lFt2AQAAmY3iAwAA2Cpliw+Xy6Uf/vCHcrlcToeSkshP78hP78hP78hP78jPtZGbxKTcCacAACCzpezMBwAAyEwUHwAAwFYUHwAAwFYUHwAAwFYUHwAAwFYpW3ysW7dOo0ePVkFBgSorK3Xo0CGnQ7JdXV2dJk+erKFDh6q0tFSzZ89WS0tLXJ9Lly6purpaJSUlKiws1Ny5cxUOhx2K2Flr1qxRTk6OampqrG3Znp/3339f3/zmN1VSUqIhQ4Zo/PjxOnLkiLXfGKNVq1ZpxIgRGjJkiILBoE6ePOlgxPbp6urSypUrVV5eriFDhugzn/mMfvSjH8U9FCub8rN//37df//9KisrU05OjrZt2xa3P5FcnD9/XgsWLJDb7VZRUZEWLVqkCxcu2PhbDJze8nP58mUtX75c48eP180336yysjItXLhQZ86ciXuNTM5P0kwK2rx5s8nPzze/+93vzBtvvGG+853vmKKiIhMOh50OzVbTp083GzZsMMePHzfNzc3mq1/9qvH7/ebChQtWn8WLF5tRo0aZ+vp6c+TIETN16lRzxx13OBi1Mw4dOmRGjx5tJkyYYJYtW2Ztz+b8nD9/3txyyy3moYceMo2Njebdd981u3btMu+8847VZ82aNcbj8Zht27aZ119/3TzwwAOmvLzcfPTRRw5Gbo/Vq1ebkpISs337dnPq1CmzZcsWU1hYaJ599lmrTzbl529/+5t56qmnzEsvvWQkma1bt8btTyQXM2bMMJ///OfNwYMHzT/+8Q9z6623mvnz59v8mwyM3vLT1tZmgsGgefHFF81bb71lGhoazJQpU0xFRUXca2RyfpKVksXHlClTTHV1tfV9V1eXKSsrM3V1dQ5G5bxz584ZSWbfvn3GmP/3Bz948GCzZcsWq8+bb75pJJmGhganwrRde3u7GTNmjNm9e7e56667rOIj2/OzfPlyc+edd15zfywWMz6fz/zsZz+ztrW1tRmXy2X+9Kc/2RGio+677z7z7W9/O27bnDlzzIIFC4wx2Z2fnv9cE8nFiRMnjCRz+PBhq8+OHTtMTk6Oef/9922L3Q5XK856OnTokJFk3nvvPWNMduUnESm37NLZ2ammpiYFg0FrW25uroLBoBoaGhyM
zHmRSESSVFxcLElqamrS5cuX43I1duxY+f3+rMpVdXW17rvvvrg8SOTnL3/5iyZNmqSvf/3rKi0t1cSJE/XrX//a2n/q1CmFQqG4/Hg8HlVWVmZFfu644w7V19fr7bffliS9/vrrOnDggGbOnCmJ/HSXSC4aGhpUVFSkSZMmWX2CwaByc3PV2Nhoe8xOi0QiysnJUVFRkSTy01PKPdX2gw8+UFdXl7xeb9x2r9ert956y6GonBeLxVRTU6Np06Zp3LhxkqRQKKT8/Hzrj/sTXq9XoVDIgSjtt3nzZv3zn//U4cOHr9iX7fl59913tX79etXW1urJJ5/U4cOH9b3vfU/5+fmqqqqycnC1Yy0b8vPEE08oGo1q7NixGjRokLq6urR69WotWLBAkrI+P90lkotQKKTS0tK4/Xl5eSouLs66fF26dEnLly/X/PnzrSfbkp94KVd84Oqqq6t1/PhxHThwwOlQUkZra6uWLVum3bt3q6CgwOlwUk4sFtOkSZP0k5/8RJI0ceJEHT9+XM8//7yqqqocjs55f/7zn/XCCy9o06ZN+tznPqfm5mbV1NSorKyM/KDPLl++rG984xsyxmj9+vVOh5OyUm7ZZfjw4Ro0aNAVVySEw2H5fD6HonLW0qVLtX37du3du1cjR460tvt8PnV2dqqtrS2uf7bkqqmpSefOndMXv/hF5eXlKS8vT/v27dMvfvEL5eXlyev1ZnV+RowYodtvvz1u22233abTp09LkpWDbD3Wvv/97+uJJ57QvHnzNH78eH3rW9/SY489prq6Oknkp7tEcuHz+XTu3Lm4/R9//LHOnz+fNfn6pPB47733tHv3bmvWQyI/PaVc8ZGfn6+KigrV19db22KxmOrr6xUIBByMzH7GGC1dulRbt27Vnj17VF5eHre/oqJCgwcPjstVS0uLTp8+nRW5uvfee3Xs2DE1NzdbbdKkSVqwYIH1dTbnZ9q0aVdcmv3222/rlltukSSVl5fL5/PF5ScajaqxsTEr8vPhhx8qNzf+I3DQoEGKxWKSyE93ieQiEAiora1NTU1NVp89e/YoFoupsrLS9pjt9knhcfLkSf39739XSUlJ3P5sz88VnD7j9Wo2b95sXC6X2bhxozlx4oR55JFHTFFRkQmFQk6HZqslS5YYj8djXn31VXP27Fmrffjhh1afxYsXG7/fb/bs2WOOHDliAoGACQQCDkbtrO5XuxiT3fk5dOiQycvLM6tXrzYnT540L7zwgrnpppvMH//4R6vPmjVrTFFRkXn55ZfNv/71LzNr1qyMvZS0p6qqKvOpT33KutT2pZdeMsOHDzePP/641Seb8tPe3m6OHj1qjh49aiSZn//85+bo0aPW1RqJ5GLGjBlm4sSJprGx0Rw4cMCMGTMmYy4l7S0/nZ2d5oEHHjAjR440zc3NcZ/XHR0d1mtkcn6SlZLFhzHG/PKXvzR+v9/k5+ebKVOmmIMHDzodku0kXbVt2LDB6vPRRx+Z7373u2bYsGHmpptuMl/72tfM2bNnnQvaYT2Lj2zPzyuvvGLGjRtnXC6XGTt2rPnVr34Vtz8Wi5mVK1car9drXC6Xuffee01LS4tD0dorGo2aZcuWGb/fbwoKCsynP/1p89RTT8X9s8im/Ozdu/eqnzdVVVXGmMRy8d///tfMnz/fFBYWGrfbbR5++GHT3t7uwG/T/3rLz6lTp675eb13717rNTI5P8nKMabb7fwAAAAGWMqd8wEAADIbxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALDV/wHgUGBrLgpAzQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0.0006, norm: 0.0004, param_norm: 938.1175, vb_loss: 0.0003, ce_loss: 0.2727:  58%|█████▊    | 900/1563 [01:05<00:27, 24.27it/s]"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVn0lEQVR4nO3df2yV5f3/8VdL29Nqe05pWc+ho0e6SYYOcKzQcsRsbp4FmFEYbBPCRlUyAyuO2mRidbDEjZVsyUQXxWzLYMtkuCaCk00JKwhjqS3tqLMiFWMjjXAOOtZzCkqLPdf3j292Phwopaec3ufX85FcCb3vq/d59331nL657vu+7gxjjBEAAIBFMuMdAAAASC8UHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFJjVnw8/fTTmjx5snJzc1VVVaXW1taxeikAAJBEMsbi2S7PP/+8VqxYoWeffVZVVVXavHmzGhsb1dXVpZKSkmG/NxQK6eTJkyooKFBGRkasQwMAAGPAGKO+vj6VlpYqM/MqcxtmDFRWVpqamprw14ODg6a0tNQ0NDRc9Xt7enqMJBqNRqPRaEnYenp6rvq3PuanXQYGBtTe3i6v1xvelpmZKa/Xq+bm5sv69/f3KxgMhpvhIbsAACStgoKCq/aJefHx4YcfanBwUE6nM2K70+mUz+e7rH9DQ4McDke4ud3uWIcEAAAsMpJLJuJ+t0t9fb0CgUC49fT0xDskAAAwhrJifcAJEyZo3Lhx8vv9Edv9fr9cLtdl/W02m2w2W6zDAAAACSrmMx85OTmqqKhQU1NTeFsoFFJTU5M8Hk+sXw4AACSZmM98SFJdXZ2qq6s1a9YsVVZWavPmzTp37pzuu+++sXg5AACQRMak+Ljnnnv0wQcfaMOGDfL5fPrCF76gV1555bKLUAEAQPoZk0XGrkUwGJTD4Yh3GAAAYBQCgYDsdvuwfeJ+twsAAEgvFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSWfEOANfmv//977D7x48fP+K+V/o+JIbhxo9xTh0jHeer9R3u+xB/YzHOQ31vomLmAwAAWIriAwAAWCrq4uPgwYO66667VFpaqoyMDO3atStivzFGGzZs0MSJE5WXlyev16vjx4/HKl4AAJDkor7m49y5c7rlllt0//33a/HixZft//nPf66nnnpKv//971VeXq7169dr3rx5Onr0qHJzc2MSdCqI5hxeMrwO/k88cp6K54QTXaKPM2LH6rynwzhHXXwsWLBACxYsGHKfMUabN2/Wj370Iy1cuFCS9Ic//EFOp1O7du3S0qVLry1aAACQ9GJ6zUd3d7d8Pp+8Xm94m8PhUFVVlZqbm4f8nv7+fgWDwYgGAABSV0yLD5/PJ0lyOp0R251OZ3jfpRoaGuRwOMKtrKwsliEBAIAEE/d1Purr61VXVxf+OhgMJlUBkmjn5jifb414jzvjbA3GOT3Ee5yl9BvrmM58uFwuSZLf74/Y7vf7w/suZbPZZLfbIxoAAEhdMS0+ysvL5XK51NTUFN4WDAbV0tIij8cTy5cCAABJKurTLmfPntU777wT/rq7u1sdHR0qKiqS2+1WbW2tfvrTn2rKlCnhW21LS0u1aNGiWMadMK5lGdxojjsWr4H4Y5zTA+OcHhjnkYu6+Ghra9NXvvKV8Nf/u16jurpa27Zt08MPP6xz587pgQceUG9vr2677Ta98sorrPEBAAAkjaL4uP3222WMueL+jIwMPf7443r88cevKTAAAJCaeLYLAACwVIYZbhojDoLBoBwOR7zDiAmWvIYU+XvA
OKcuxjk9XPq5zlhfLhAIXPXOVWY+AACApSg+AACApSg+AACApeK+vHoqG6s1QJBcOCecHhjn9MA4xwYzHwAAwFIUHwAAwFKcdrHQxdN1nIIBAKQrZj4AAIClKD4AAIClKD4AAICluOYjTq52Gy63c6UmxhkAmPkAAAAWo/gAAACW4rQLAMQYp9fSx0ifZszvRCRmPgAAgKUoPgAAgKUoPgAAgKUyjDEm3kFcLBgMyuFwxDuMuBvt8uvpfh4x2UQzzoxtahrud4AxT27p+v4OBAKy2+3D9mHmAwAAWIriAwAAWIriAwAAWIp1PhJUNPeLI3ldPM5XG9eRrieA5HK1Ry0geUXz/k43zHwAAABLRVV8NDQ0aPbs2SooKFBJSYkWLVqkrq6uiD7nz59XTU2NiouLlZ+fryVLlsjv98c0aAAAkLyiutV2/vz5Wrp0qWbPnq1PPvlEjz76qDo7O3X06FFdf/31kqTVq1frr3/9q7Zt2yaHw6E1a9YoMzNT//znP0f0Gtxqe21Ywjd1cUtm+uH9nLpS+f08klttr2mdjw8++EAlJSU6cOCAvvSlLykQCOhTn/qUtm/frm9+85uSpGPHjummm25Sc3Oz5syZc9VjUnxcGz6sUlcqf1hhaLyfU1cqv5/HfJ2PQCAgSSoqKpIktbe368KFC/J6veE+U6dOldvtVnNz85DH6O/vVzAYjGgAACB1jbr4CIVCqq2t1dy5czVt2jRJks/nU05OjgoLCyP6Op1O+Xy+IY/T0NAgh8MRbmVlZaMNCQAAJIFR32pbU1Ojzs5OHTp06JoCqK+vV11dXfjrYDBIARIFbs+ExDinkuHe05yGSQ/pMM6jKj7WrFmj3bt36+DBg5o0aVJ4u8vl0sDAgHp7eyNmP/x+v1wu15DHstlsstlsowkDAAAkoahOuxhjtGbNGu3cuVP79u1TeXl5xP6KigplZ2erqakpvK2rq0snTpyQx+OJTcQAACCpRTXzUVNTo+3bt+vFF19UQUFB+DoOh8OhvLw8ORwOrVy5UnV1dSoqKpLdbteDDz4oj8czojtdEHupOF0HpCvez0gVURUfW7ZskSTdfvvtEdu3bt2qe++9V5L0xBNPKDMzU0uWLFF/f7/mzZunZ555JibBAgCA5HdN63yMBdb5iM7VLjjlf0qpY6TPhmDMk1sqr/+A/xPNs16SbdzHfJ0PAACAaPFU2yTA0xDTw7WMc7L9zyjd8Z5OD6Md53R4PzPzAQAALEXxAQAALEXxAQAALMU1HxYabglszg2mjuGWRua6jtTBOKcPPrtjj5kPAABgKYoPAABgKU67xEkqLzCDSCwOlh54T6cHxjk2mPkAAACWovgAAACWovgAAACW4pqPGIvVssmcK0xsjHN6YJzTA+NsPWY+AACApSg+AACApTjtMoThVi4cav9oMD2XGMZi5cJLMdaJjXFOHWO16uyVjonRY+YDAABYiuIDAABYiuIDAABYims+RmC4c4Wc/0sdjHPqYkns9MR7OnEx8wEAACxF8QEAACxF8QEAACyVYYwx8Q7iYsFgUA6HI95hAACAUQgEArLb7cP2YeYDAABYKqriY8uWLZoxY4bsdrvsdrs8Ho9efvnl8P7z58+rpqZGxcXFys/P15IlS+T3+2MeNAAASF5RFR+TJk3Spk2b1N7erra2Nn31q1/VwoUL9eabb0qSHnroIb300ktqbGzUgQMHdPLkSS1evHhMAgcAAEnKXKPx48eb3/72t6a3t9dkZ2ebxsbG8L633nrLSDLNzc0jPl4gEDCSaDQajUajJWELBAJX/Vs/6ms+BgcHtWPHDp07d04ej0ft7e26cOGCvF5vuM/UqVPldrvV3Nx8xeP09/crGAxGNAAAkLqiLj7eeOMN5efny2azadWqVdq5c6duvvlm+Xw+5eTkqLCwMKK/0+mUz+e74vEaGhrkcDjCraysLOofAgAAJI+oi4/Pfe5z6ujoUEtLi1avXq3q6modPXp01AHU
19crEAiEW09Pz6iPBQAAEl/Uz3bJycnRjTfeKEmqqKjQ4cOH9eSTT+qee+7RwMCAent7I2Y//H6/XC7XFY9ns9lks9mijxwAACSla17nIxQKqb+/XxUVFcrOzlZTU1N4X1dXl06cOCGPx3OtLwMAAFJEVDMf9fX1WrBggdxut/r6+rR9+3a9+uqr2rNnjxwOh1auXKm6ujoVFRXJbrfrwQcflMfj0Zw5c8YqfgAAkGSiKj5Onz6tFStW6NSpU3I4HJoxY4b27Nmjr33ta5KkJ554QpmZmVqyZIn6+/s1b948PfPMM2MSOAAASE482wUAAMQMz3YBAAAJh+IDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYKuGKjwR71AwAAIjCSP6OJ1zx0dfXF+8QAADAKI3k73jCPdU2FArp5MmTMsbI7Xarp6fnqk/HS0fBYFBlZWXk5wrIz/DIz/DIz/DIz5Wlc26MMerr61NpaakyM4ef28iyKKYRy8zM1KRJkxQMBiVJdrs97QYwGuRneORneORneORneOTnytI1Nw6HY0T9Eu60CwAASG0UHwAAwFIJW3zYbDb9+Mc/ls1mi3coCYn8DI/8DI/8DI/8DI/8XBm5GZmEu+AUAACktoSd+QAAAKmJ4gMAAFiK4gMAAFiK4gMAAFiK4gMAAFgqYYuPp59+WpMnT1Zubq6qqqrU2toa75As19DQoNmzZ6ugoEAlJSVatGiRurq6IvqcP39eNTU1Ki4uVn5+vpYsWSK/3x+niONr06ZNysjIUG1tbXhbuufn/fff13e+8x0VFxcrLy9P06dPV1tbW3i/MUYbNmzQxIkTlZeXJ6/Xq+PHj8cxYusMDg5q/fr1Ki8vV15enj772c/qJz/5ScRDsdIpPwcPHtRdd92l0tJSZWRkaNeuXRH7R5KLM2fOaPny5bLb7SosLNTKlSt19uxZC3+KsTNcfi5cuKB169Zp+vTpuv7661VaWqoVK1bo5MmTEcdI5fxEzSSgHTt2mJycHPO73/3OvPnmm+Z73/ueKSwsNH6/P96hWWrevHlm69atprOz03R0dJivf/3rxu12m7Nnz4b7rFq1ypSVlZmmpibT1tZm5syZY2699dY4Rh0fra2tZvLkyWbGjBlm7dq14e3pnJ8zZ86YG264wdx7772mpaXFvPvuu2bPnj3mnXfeCffZtGmTcTgcZteuXeb11183d999tykvLzcff/xxHCO3xsaNG01xcbHZvXu36e7uNo2NjSY/P988+eST4T7plJ+//e1v5rHHHjMvvPCCkWR27twZsX8kuZg/f7655ZZbzGuvvWb+8Y9/mBtvvNEsW7bM4p9kbAyXn97eXuP1es3zzz9vjh07Zpqbm01lZaWpqKiIOEYq5ydaCVl8VFZWmpqamvDXg4ODprS01DQ0NMQxqvg7ffq0kWQOHDhgjPn/v/DZ2dmmsbEx3Oett94ykkxzc3O8wrRcX1+fmTJlitm7d6/58pe/HC4+0j0/69atM7fddtsV94dCIeNyucwvfvGL8Lbe3l5js9nMn/70JytCjKs777zT3H///RHbFi9ebJYvX26MSe/8XPrHdSS5OHr0qJFkDh8+HO7z8ssvm4yMDPP+++9bFrsVhirOLtXa2mokmffee88Yk175GYmEO+0yMDCg9vZ2eb3e8LbMzEx5vV41NzfHMbL4CwQCkqSioiJJUnt7uy5cuBCRq6lTp8rtdqdVrmpqanTnnXdG5EEiP3/5y180a9Ysfetb31JJSYlmzpyp3/zmN+H93d3d8vl8EflxOByqqqpKi/zceuutampq0ttvvy1Jev3113Xo0CEtWLBAEvm52Ehy0dzcrMLCQs2aNSvcx+v1KjMzUy0tLZbHHG+BQEAZGRkqLCyURH4ulXBPtf3www81ODgop9MZsd3pdOrYsWNxiir+QqGQamtrNXfuXE2bNk2S5PP5
lJOTE/7l/h+n0ymfzxeHKK23Y8cO/etf/9Lhw4cv25fu+Xn33Xe1ZcsW1dXV6dFHH9Xhw4f1gx/8QDk5Oaqurg7nYKj3Wjrk55FHHlEwGNTUqVM1btw4DQ4OauPGjVq+fLkkpX1+LjaSXPh8PpWUlETsz8rKUlFRUdrl6/z581q3bp2WLVsWfrIt+YmUcMUHhlZTU6POzk4dOnQo3qEkjJ6eHq1du1Z79+5Vbm5uvMNJOKFQSLNmzdLPfvYzSdLMmTPV2dmpZ599VtXV1XGOLv7+/Oc/67nnntP27dv1+c9/Xh0dHaqtrVVpaSn5wahduHBB3/72t2WM0ZYtW+IdTsJKuNMuEyZM0Lhx4y67I8Hv98vlcsUpqvhas2aNdu/erf3792vSpEnh7S6XSwMDA+rt7Y3ony65am9v1+nTp/XFL35RWVlZysrK0oEDB/TUU08pKytLTqczrfMzceJE3XzzzRHbbrrpJp04cUKSwjlI1/faD3/4Qz3yyCNaunSppk+fru9+97t66KGH1NDQIIn8XGwkuXC5XDp9+nTE/k8++URnzpxJm3z9r/B47733tHfv3vCsh0R+LpVwxUdOTo4qKirU1NQU3hYKhdTU1CSPxxPHyKxnjNGaNWu0c+dO7du3T+Xl5RH7KyoqlJ2dHZGrrq4unThxIi1ydccdd+iNN95QR0dHuM2aNUvLly8P/zud8zN37tzLbs1+++23dcMNN0iSysvL5XK5IvITDAbV0tKSFvn56KOPlJkZ+RE4btw4hUIhSeTnYiPJhcfjUW9vr9rb28N99u3bp1AopKqqKstjttr/Co/jx4/r73//u4qLiyP2p3t+LhPvK16HsmPHDmOz2cy2bdvM0aNHzQMPPGAKCwuNz+eLd2iWWr16tXE4HObVV181p06dCrePPvoo3GfVqlXG7Xabffv2mba2NuPxeIzH44lj1PF18d0uxqR3flpbW01WVpbZuHGjOX78uHnuuefMddddZ/74xz+G+2zatMkUFhaaF1980fz73/82CxcuTNlbSS9VXV1tPv3pT4dvtX3hhRfMhAkTzMMPPxzuk0756evrM0eOHDFHjhwxkswvf/lLc+TIkfDdGiPJxfz5883MmTNNS0uLOXTokJkyZUrK3Eo6XH4GBgbM3XffbSZNmmQ6OjoiPq/7+/vDx0jl/EQrIYsPY4z51a9+Zdxut8nJyTGVlZXmtddei3dIlpM0ZNu6dWu4z8cff2y+//3vm/Hjx5vrrrvOfOMb3zCnTp2KX9Bxdmnxke75eemll8y0adOMzWYzU6dONb/+9a8j9odCIbN+/XrjdDqNzWYzd9xxh+nq6opTtNYKBoNm7dq1xu12m9zcXPOZz3zGPPbYYxF/LNIpP/v37x/y86a6utoYM7Jc/Oc//zHLli0z+fn5xm63m/vuu8/09fXF4aeJveHy093dfcXP6/3794ePkcr5iVaGMRct5wcAADDGEu6aDwAAkNooPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKX+H9ZtKdLUYoywAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0.0004, norm: 0.0002, param_norm: 938.9938, vb_loss: 0.0002, ce_loss: 0.1951:  77%|███████▋  | 1200/1563 [01:26<00:14, 25.02it/s]"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYT0lEQVR4nO3df0xV9/3H8ReIXGyRi2AAmVLZ2sx2/qhDRWoz+4NFXdP6K2s1bqI1a3TYSUlWazud6eJwW1Kxi7XZL+iyOp1JwdWtGoe/5oaoTNpaW2pTpqR6cZ2Fi7aChc/3j317cs9Vr1y8nHsv9/lITnJ+ce6H95tzeeecz+ecOGOMEQAAgEPiw90AAAAQWyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAo/qs+Ni0aZNGjhyppKQk5efn68iRI331UQAAIIrE9cW7XbZt26aFCxfq5ZdfVn5+vsrLy7V9+3Y1NjYqIyMj4M92d3fr7NmzGjx4sOLi4kLdNAAA0AeMMWpvb1d2drbi429wbcP0gUmTJpni4mJruaury2RnZ5uysrIb/mxzc7ORxMTExMTExBSFU3Nz8w3/14f8tktnZ6fq6+tVWFhorYuPj1dhYaFqa2uv2r+jo0Ner9eaDC/ZBQAgag0ePPiG+4S8+Pj444/V1dWlzMxM2/rMzEx5PJ6r9i8rK5Pb7bamnJycUDcJAAA4pCddJsI+2mXVqlVqa2uzpubm5nA3CQAA9KGEUB9w6NChGjBggFpaWmzrW1palJWVddX+LpdLLpcr1M0AAAARKuRXPhITE5WXl6eamhprXXd3t2pqalRQUBDqjwMAAFEm5Fc+JKm0tFRFRUWaMGGCJk2apPLycl26dEmLFy/ui48DAABRpE+Kj8cee0z/+c9/tGbNGnk8Ht19993atWvXVZ1QAQBA7OmTh4zdDK/XK7fbHe5mAACAXmhra1NKSkrAfcI+2gUAAMQWig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOCohHA3AKFVUVFhW66srLTmFy1aZNs2a9as6x5n7dq1tuWNGzfeZMsQSuQ5NgTKs2TPNXmObr657m2eJXuuIznPXPkAAACOovgAAACOCrr4OHjwoB5++GFlZ2crLi5O1dXVtu3GGK1Zs0bDhg3ToEGDVFhYqFOnToWqvQAAIMoF3efj0qVLGjdunB5//HHNmTPnqu0///nP9eKLL+qVV15Rbm6uVq9erWnTpunkyZNKSkoKSaOjlf/9W183uo/XU/v377ct9/ReYUNDQ8Bl9Jx/nu+77z5rPjU1NSSf0ds8S/bckufeC/f5fKPPIc+9Fyi3/voi18Hk2V+05Dro4mPGjBmaMWPGNbcZY1ReXq4f/ehHmjlzpiTp97//vTIzM1VdXa158+bdXGsBAEDUC2mfj6amJnk8HhUWFlrr3G638vPzVVtbe82f6ejokNfrtU0AAKD/Cmnx4fF4JEmZmZm29ZmZmdY2f2VlZXK73dY0YsSIUDYJAABEmDhjjOn1D8fFqaqqyrof9c9//lNTpkzR2bNnNWzYMGu/Rx99VHFxcdq2bdtVx+jo6FBHR4e17PV6o6oAaWpqsi2H6p5+b/nfKwxkx44d1vxtt91m23b69OlQNalf6C95luy5Js/SuHHjbMu+nejJc//le06HO89Sz3MdKM9SZOS6ra1NKSkpAfcJ6ZWPrKwsSVJLS4ttfUtLi7XNn8vlUkpKim0CAAD9V0iLj9zcXGVlZammpsZa5/V6VVdXp4KCglB+FAAAiFJB33a5ePGiPvjgA0nS+PHj9cILL+j+++9X
WlqacnJy9LOf/Uzr16+3DbV96623ejzU1uv1yu129+63CQP/S16hGubke5zW1lbbNt+hm/6GDBkSks+H3Rejt77g//jj3iLP4Td16lTb8r///W9rvi/OZ8mea/Ic+T755JMe7Xf33XfblsvLy635QHmW+leue3LbJeihtseOHdP9999vLZeWlkqSioqKVFlZqaefflqXLl3SE088odbWVt17773atWtXzD/jAwAA/E/Qxcd9992nQBdL4uLi9Pzzz+v555+/qYYBAID+iXe7AAAARwV95QN2/sOafPsC+A/fCubx5r63tvz59jMpKSm5URMRAv7D2/zvz/rmJFA/gd7mWSLXfeXAgQPX3eZ/zvb2dQXkObr5nu/+/T98c+3//2D27NnWPHm248oHAABwFMUHAABwFMUHAABw1E09Xr0vRNtzPoJRVVVlzfuP+fZ/tK5/HwNED988S/Zck+f+gzwD1+b449UBAABuhOIDAAA4itsuYeI/XGv//v22Zd8hWug//C/Vk+f+w/ec5nxGLOO2CwAAiDgUHwAAwFEUHwAAwFE8Xj1M/O8J+w+99X3Nd6DHPyPy+fbz8L/3T577D99zOtD5LJHraBao3xZ57jmufAAAAEdRfAAAAEcx1DZM/N9w6P8WTN9LuK2trdfd9sorr4S4ZQi1QG+8DZRnf4sXLw5hqxBqPc2zZM/12rVrbdv834yKyFJRUWFb9n27MXn+H4baAgCAiEPxAQAAHEXxAQAAHEWfjwgxc+ZM23JlZWWPfq6kpMS2XF1dbVtua2u7iVYhnAI9gp/HdUcX/1wOGTLkutvIc/QKJs++2/ob+nwAAICIQ/EBAAAcRfEBAAAcRZ+PCOX7zIDy8nLbNv9HN/vyHXMu8Xjf/sT3nrH//WLfc4Z+PkB4+D+/qbfP8gjVccKFPh8AACDiBFV8lJWVaeLEiRo8eLAyMjI0a9YsNTY22va5fPmyiouLlZ6eruTkZM2dO1ctLS0hbTQAAIheQd12mT59uubNm6eJEyfq888/17PPPqsTJ07o5MmTuvXWWyVJy5Yt01/+8hdVVlbK7XZr+fLlio+P1z/+8Y8efQa3Xa7mHw/fy+r+w7f8MWyvf4qlYXv9EbfJ0J/15LZLQjAH3LVrl225srJSGRkZqq+v1ze+8Q21tbXpt7/9rbZs2aIHHnhA0v+eg3/nnXfq8OHDmjx5cpC/AgAA6G9uqs/HFxV7WlqaJKm+vl5XrlxRYWGhtc+oUaOUk5Oj2traax6jo6NDXq/XNgEAgP6r18VHd3e3SkpKNGXKFI0ePVqS5PF4lJiYqNTUVNu+mZmZ8ng81zxOWVmZ3G63NY0YMaK3TQIAAFEgqNsuvoqLi3XixAkdOnTophqwatUqlZaWWster5cCxM/N3BP2HZY7depU2zaG4UaXQP17fLcx3DryBTqnb9SPyxd9faJXrD+KvVfFx/Lly7Vz504dPHhQw4cPt9ZnZWWps7NTra2ttqsfLS0tysrKuuaxXC6XXC5Xb5oBAACiUFC3XYwxWr58uaqqqrR3717l5ubatufl5WngwIGqqamx1jU2NurMmTMqKCgITYsBAEBUC+rKR3FxsbZs2aIdO3Zo8ODBVj8Ot9utQYMGye12a8mSJSotLVVaWppSUlL05JNPqqCggJEuIVRVVdXjfRctWmTNc/k9ugST54aGBmuePEeX3p7PiC43yrPvrZZYyHNQxcfmzZslXf1474qKCitYGzZsUHx8vObOnauOjg5NmzZNL730UkgaCwAAol9QxUdPnkeWlJSkTZs2adOmTb1uFAAA6L94twsAAHBUr4fa4ub4v7Vw5MiRtuWSkhJrPtBbbG/E997hjh07en0chIb/cOdQ5bm1tbXXP4vQC5Rnqfe55hyOPL65rqystG3zf+ZVT/n3+eiPeefKBwAAcBTFBwAAcBTFBwAAcBR9PhzU1NRkzfvf06uuru7VMX2f7yBJa9eutS3zzAfn+eZZsue6t3n2
xyPUw6+v8uzbbyBUfy+4OU5/d8fC+cyVDwAA4CiKDwAA4Kg405MnhznI6/XK7XaHuxm95j/EzleoLs+Vl5db8/1xCFY0cDrPErkOh77Is2TPNXkOv0B5lkJzTsdSntva2pSSkhJwH658AAAAR1F8AAAAR1F8AAAARzHUthcqKiqsef/H6dKvo/8gz7HBN89SaIa60n8nMmzYsMGa93/UeaiGNMdqv46bxZUPAADgKIoPAADgKIba/r+qqipr3v8Nof6X627m7aNf2L9/v2159uzZN31M3Jj/Uyl989AXefb/DPLsjHHjxtmWfd8qS577D9/vbSnwd3df5Fki19fCUFsAABBxKD4AAICjKD4AAICj6PNxDZ988kmP9/V/u6jvvWV/3BuMPD3NdTB59n/rZVtbW5CtQqj19pzmfI4ufZFnyX5Ocz7fGH0+AABAxKH4AAAAjqL4AAAAjqLPBwAACBn6fAAAgIgTVPGxefNmjR07VikpKUpJSVFBQYHeeOMNa/vly5dVXFys9PR0JScna+7cuWppaQl5owEAQPQKqvgYPny41q9fr/r6eh07dkwPPPCAZs6cqXfeeUeS9NRTT+n111/X9u3bdeDAAZ09e1Zz5szpk4YDAIAoZW7SkCFDzG9+8xvT2tpqBg4caLZv325te/fdd40kU1tb2+PjtbW1GUlMTExMTExMUTi1tbXd8H99r/t8dHV1aevWrbp06ZIKCgpUX1+vK1euqLCw0Npn1KhRysnJUW1t7XWP09HRIa/Xa5sAAED/FXTx8fbbbys5OVkul0tLly5VVVWV7rrrLnk8HiUmJl71xsjMzEx5PJ7rHq+srExut9uaRowYEfQvAQAAokfQxcdXv/pVNTQ0qK6uTsuWLVNRUZFOnjzZ6wasWrVKbW1t1tTc3NzrYwEAgMiXEOwPJCYm6vbbb5ck5eXl6ejRo9q4caMee+wxdXZ2qrW11Xb1o6WlRVlZWdc9nsvlksvlCr7lAAAgKt30cz66u7vV0dGhvLw8DRw4UDU1Nda2xsZGnTlzRgUFBTf7MQAAoJ8I6srHqlWrNGPGDOXk5Ki9vV1btmzR/v37tXv3brndbi1ZskSlpaVKS0tTSkqKnnzySRUUFGjy5Ml91X4AABBlgio+zp8/r4ULF+rcuXNyu90aO3asdu/erW9+85uSpA0bNig+Pl5z585VR0eHpk2bppdeeqlPGg4AAKIT73YBAAAhw7tdAABAxKH4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjoq44iPCXjUDAACC0JP/4xFXfLS3t4e7CQAAoJd68n884t5q293drbNnz8oYo5ycHDU3N9/w7XixyOv1asSIEcTnOohPYMQnMOITGPG5vliOjTFG7e3tys7OVnx84GsbCQ61qcfi4+M1fPhweb1eSVJKSkrMJTAYxCcw4hMY8QmM+ARGfK4vVmPjdrt7tF/E3XYBAAD9G8UHAABwVMQWHy6XSz/+8Y/lcrnC3ZSIRHwCIz6BEZ/AiE9gxOf6iE3PRFyHUwAA0L9F7JUPAADQP1F8AAAAR1F8AAAAR1F8AAAAR1F8AAAAR0Vs8bFp0yaNHDlSSUlJys/P15EjR8LdJMeVlZVp4sSJGjx4sDIyMjRr1iw1Njba9rl8+bKKi4uVnp6u5ORkzZ07Vy0tLWFqcXitX79ecXFxKikpsdbFenw++ugjfec731F6eroGDRqkMWPG6NixY9Z2Y4zWrFmjYcOGadCgQSosLNSpU6fC2GLndHV1afXq1crNzdWgQYP0la98RT/5yU9sL8WKpfgcPHhQDz/8sLKzsxUXF6fq6mrb9p7E4sKFC1qwYIFSUlKUmpqqJUuW6OLFiw7+Fn0nUHyuXLmilStXasyYMbr11luVnZ2thQsX6uzZs7Zj9Of4BM1EoK1bt5rExETzu9/9zrzzzjvme9/7nklNTTUtLS3h
bpqjpk2bZioqKsyJEydMQ0OD+da3vmVycnLMxYsXrX2WLl1qRowYYWpqasyxY8fM5MmTzT333BPGVofHkSNHzMiRI83YsWPNihUrrPWxHJ8LFy6Y2267zSxatMjU1dWZDz/80Ozevdt88MEH1j7r1683brfbVFdXmzfffNM88sgjJjc313z22WdhbLkz1q1bZ9LT083OnTtNU1OT2b59u0lOTjYbN2609oml+Pz1r381zz33nHnttdeMJFNVVWXb3pNYTJ8+3YwbN84cPnzY/P3vfze33367mT9/vsO/Sd8IFJ/W1lZTWFhotm3bZt577z1TW1trJk2aZPLy8mzH6M/xCVZEFh+TJk0yxcXF1nJXV5fJzs42ZWVlYWxV+J0/f95IMgcOHDDG/O8PfuDAgWb79u3WPu+++66RZGpra8PVTMe1t7ebO+64w+zZs8dMnTrVKj5iPT4rV640995773W3d3d3m6ysLPOLX/zCWtfa2mpcLpf54x//6EQTw+qhhx4yjz/+uG3dnDlzzIIFC4wxsR0f/3+uPYnFyZMnjSRz9OhRa5833njDxMXFmY8++sixtjvhWsWZvyNHjhhJ5vTp08aY2IpPT0TcbZfOzk7V19ersLDQWhcfH6/CwkLV1taGsWXh19bWJklKS0uTJNXX1+vKlSu2WI0aNUo5OTkxFavi4mI99NBDtjhIxOfPf/6zJkyYoG9/+9vKyMjQ+PHj9etf/9ra3tTUJI/HY4uP2+1Wfn5+TMTnnnvuUU1Njd5//31J0ptvvqlDhw5pxowZkoiPr57Eora2VqmpqZowYYK1T2FhoeLj41VXV+d4m8Otra1NcXFxSk1NlUR8/EXcW20//vhjdXV1KTMz07Y+MzNT7733XphaFX7d3d0qKSnRlClTNHr0aEmSx+NRYmKi9cf9hczMTHk8njC00nlbt27Vv/71Lx09evSqbbEenw8//FCbN29WaWmpnn32WR09elQ/+MEPlJiYqKKiIisG1zrXYiE+zzzzjLxer0aNGqUBAwaoq6tL69at04IFCyQp5uPjqyex8Hg8ysjIsG1PSEhQWlpazMXr8uXLWrlypebPn2+92Zb42EVc8YFrKy4u1okTJ3To0KFwNyViNDc3a8WKFdqzZ4+SkpLC3ZyI093drQkTJuinP/2pJGn8+PE6ceKEXn75ZRUVFYW5deH3pz/9Sa+++qq2bNmir33ta2poaFBJSYmys7OJD3rtypUrevTRR2WM0ebNm8PdnIgVcbddhg4dqgEDBlw1IqGlpUVZWVlhalV4LV++XDt37tS+ffs0fPhwa31WVpY6OzvV2tpq2z9WYlVfX6/z58/r61//uhISEpSQkKADBw7oxRdfVEJCgjIzM2M6PsOGDdNdd91lW3fnnXfqzJkzkmTFIFbPtR/+8Id65plnNG/ePI0ZM0bf/e539dRTT6msrEwS8fHVk1hkZWXp/Pnztu2ff/65Lly4EDPx+qLwOH36tPbs2WNd9ZCIj7+IKz4SExOVl5enmpoaa113d7dqampUUFAQxpY5zxij5cuXq6qqSnv37lVubq5te15engYOHGiLVWNjo86cORMTsXrwwQf19ttvq6GhwZomTJigBQsWWPOxHJ8pU6ZcNTT7/fff12233SZJys3NVVZWli0+Xq9XdXV1MRGfTz/9VPHx9q/AAQMGqLu7WxLx8dWTWBQUFKi1tVX19fXWPnv37lV3d7fy8/Mdb7PTvig8Tp06pb/97W9KT0+3bY/1+Fwl3D1er2Xr1q3G5XKZyspKc/LkSfPEE0+Y1NRU4/F4wt00Ry1btsy43W6zf/9+c+7cOWv69NNPrX2WLl1qcnJyzN69e82xY8dMQUGBKSgoCGOrw8t3tIsxsR2fI0eOmISEBLNu3Tpz6tQp8+qrr5pbbrnF/OEPf7D2Wb9+vUlNTTU7duwwb731lpk5c2a/HUrqr6ioyHzpS1+yhtq+9tprZujQoebpp5+29oml+LS3t5vj
x4+b48ePG0nmhRdeMMePH7dGa/QkFtOnTzfjx483dXV15tChQ+aOO+7oN0NJA8Wns7PTPPLII2b48OGmoaHB9n3d0dFhHaM/xydYEVl8GGPML3/5S5OTk2MSExPNpEmTzOHDh8PdJMdJuuZUUVFh7fPZZ5+Z73//+2bIkCHmlltuMbNnzzbnzp0LX6PDzL/4iPX4vP7662b06NHG5XKZUaNGmV/96le27d3d3Wb16tUmMzPTuFwu8+CDD5rGxsYwtdZZXq/XrFixwuTk5JikpCTz5S9/2Tz33HO2fxaxFJ99+/Zd8/umqKjIGNOzWPz3v/818+fPN8nJySYlJcUsXrzYtLe3h+G3Cb1A8Wlqarru9/W+ffusY/Tn+AQrzhifx/kBAAD0sYjr8wEAAPo3ig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOCo/wNZvNGki+7vXAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loss: 0.0005, norm: 0.0005, param_norm: 939.0556, vb_loss: 0.0001, ce_loss: 0.1722:  81%|████████  | 1260/1563 [01:37<00:23, 12.98it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 31\u001b[0m\n\u001b[1;32m     28\u001b[0m norm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mclip_grad_norm_(d3pm\u001b[38;5;241m.\u001b[39mx0_model\u001b[38;5;241m.\u001b[39mparameters(), \u001b[38;5;241m0.01\u001b[39m)\n\u001b[1;32m     30\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 31\u001b[0m     param_norm \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m([\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m d3pm\u001b[38;5;241m.\u001b[39mx0_model\u001b[38;5;241m.\u001b[39mparameters()])\n\u001b[1;32m     32\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m loss_ema \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m     33\u001b[0m     loss_ema \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
      "File \u001b[0;32m~/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/functional.py:1610\u001b[0m, in \u001b[0;36mnorm\u001b[0;34m(input, p, dim, keepdim, out, dtype)\u001b[0m\n\u001b[1;32m   1608\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m p \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfro\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (dim \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(dim, (\u001b[38;5;28mint\u001b[39m, torch\u001b[38;5;241m.\u001b[39mSymInt)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(dim) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m   1609\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m out \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1610\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinalg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvector_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1611\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1612\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mvector_norm(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m2\u001b[39m, _dim, keepdim, dtype\u001b[38;5;241m=\u001b[39mdtype, out\u001b[38;5;241m=\u001b[39mout)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "d3pm.train()\n",
    "\n",
    "n_epoch = 10\n",
    "device = 'cuda'\n",
    "from torchvision.utils import make_grid\n",
    "from matplotlib import pyplot as plt\n",
    "global_step = 0\n",
    "for i in range(n_epoch):\n",
    "    d3pm.train()\n",
    "    pbar = tqdm(dataloader)\n",
    "    loss_ema = None\n",
    "    for x, _ in pbar:\n",
    "        #print(x)\n",
    "        optim.zero_grad()\n",
    "        x = x.to(device)\n",
    "        # discretize x to 10 bins\n",
    "        x = (x * N).long().clamp(0, N - 1)\n",
    "        \n",
    "        loss, info = d3pm(x)\n",
    "        \n",
    "        # if loss.item() > 1000 or torch.isnan(loss):\n",
    "        #     print(f\"loss is too high, skipping, {loss.item()}\")\n",
    "        #     loss.backward()\n",
    "        #     optim.zero_grad()\n",
    "        #     continue\n",
    "        #print(loss.item())\n",
    "        loss.backward()\n",
    "        norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.01)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])\n",
    "        if loss_ema is None:\n",
    "            loss_ema = loss.item()\n",
    "        else:\n",
    "            loss_ema = 0.99 * loss_ema + 0.01 * loss.item()\n",
    "        pbar.set_description(f\"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}\")\n",
    "        optim.step()\n",
    "        global_step += 1\n",
    "\n",
    "        if global_step % 300 == 1:\n",
    "            d3pm.eval()\n",
    "\n",
    "            with torch.no_grad():\n",
    "                x = d3pm.sample(torch.randint(0, N, (4, 1, 32, 32)).cuda())\n",
    "                x_as_image = make_grid(x.float() / N, nrow=4)\n",
    "                plt.figure()\n",
    "                plt.imshow(x_as_image.permute(1, 2, 0).cpu().numpy())\n",
    "                plt.show()\n",
    "\n",
    "            d3pm.train()\n",
    "                    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeuElEQVR4nO3dfWzV5f3/8Ve56RGlPbXc9EYKFlCYQHFjUDsVUSqlZoSbLsGbbDCZBFaI0HlXoyJufsvwHoNo4gJxEXEsAsNEUIotcRYclQ5R1wDpBEZbJlvPKcUW1l6/P5adn5W782lPefeU5yO5Es7n8+513p9csS8/55xeJ8Y55wQAwEXWzboBAMCliQACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiR7WDXxXS0uLjh49qri4OMXExFi3AwDwyDmn+vp6paamqlu3c9/ndLoAOnr0qNLS0qzbAAC00+HDhzVgwIBznu+wl+BWrlypq6++WpdddpkyMzP1ySefhPVzcXFxHdUSAOAiutDv8w4JoLffflsFBQVasmSJPv30U40ePVo5OTk6duzYBX+Wl90AoGu44O9z1wHGjRvn8vPzQ4+bm5tdamqqKyoquuDPBgIBJ4nBYDAYUT4CgcB5f99H/A7o1KlTKi8vV3Z2duhYt27dlJ2drbKysjPqm5qaFAwGWw0AQNcX8QD6+uuv1dzcrKSkpFbHk5KSVFNTc0Z9UVGR/H5/aPABBAC4NJj/HVBhYaECgUBoHD582LolAMBFEPGPYfft21fdu3dXbW1tq+O1tbVKTk4+o97n88nn80W6DQBAJxfxO6DY2FiNGTNGxcXFoWMtLS0qLi5WVlZWpJ8OABClOuQPUQsKCjRr1iz98Ic/1Lhx4/Tiiy+qoaFBP//5zzvi6QAAUahDAmjmzJn65z//qSeeeEI1NTW6/vrrtWXLljM+mAAAuHTFOOecdRPfFgwG5ff7rdsAALRTIBBQfHz8Oc+bfwoOAHBpIoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAICJiAfQk08+qZiYmFZj+PDhkX4aAECU69ERk44YMULbtm37/0/So0OeBgAQxTokGXr06KHk5OSOmBoA0EV0yHtA+/fvV2pqqgYPHqx77rlHhw4dOmdtU1OTgsFgqwEA6PoiHkCZmZlas2aNtmzZolWrVqmqqko333yz6uvrz1pfVFQkv98fGmlpaZFuCQDQCcU451xHPkFdXZ0GDRqk559/XnPmzDnjfFNTk5qamkKPg8EgIQQAXUAgEFB8fPw5z3f4pwMSEhJ07bXX6sCBA2c97/P55PP5OroNAEAn0+F/B3TixAkdPHhQKSkpHf1UAIAoEvEAeuCBB1RaWqq///3v+vjjjzV9+nR1795dd911V6SfCgAQxSL+EtyRI0d011136fjx4+rXr59uuukm7dy5U/369Yv0UwEAoliHfwjBq2AwKL/fb90GAKCdLvQhBPaCAwCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJjr86xiAaNW9e/ewazvT9lELFizwVH/55ZeHXTts2DBPc+fn54dd++yzz3qa28sGx42NjZ7mXrZsmaf6pUuXeqrHf3EHBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATLAVDzrcwIEDw66NjY31
NPePfvSjsGtvuukmT3MnJCSEXZuXl+dp7mh15MgRT/UrVqwIu3b69Ome5q6vrw+79q9//aunuUtLSz3Vo224AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiRjnnLNu4tuCwaD8fr91GziP73//+57qi4uLw65l7S++lpaWsGvvvfdeT3M3NDR4bSdsR48eDbv23//+t6e5KysrvbaDswgEAoqPjz/nee6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCih3UDiD5fffWVp/rjx4+HXXup7AW3a9cuT/V1dXVh1956662e5j516lTYtb///e89zQ2cD3dAAAATngNox44dmjJlilJTUxUTE6ONGze2Ou+c0xNPPKGUlBT16tVL2dnZ2r9/f6T6BQB0EZ4DqKGhQaNHj9bKlSvPen758uVasWKFXn31Ve3atUtXXHGFcnJy1NjY2O5mAQBdh+f3gHJzc5Wbm3vWc845vfjii3rsscc0depUSdIbb7yhpKQkbdy4UXfeeWf7ugUAdBkRfQ+oqqpKNTU1ys7ODh3z+/3KzMxUWVnZWX+mqalJwWCw1QAAdH0RDaCamhpJUlJSUqvjSUlJoXPfVVRUJL/fHxppaWmRbAkA0EmZfwqusLBQgUAgNA4fPmzdEgDgIohoACUnJ0uSamtrWx2vra0Nnfsun8+n+Pj4VgMA0PVFNIDS09OVnJys4uLi0LFgMKhdu3YpKysrkk8FAIhynj8Fd+LECR04cCD0uKqqShUVFUpMTNTAgQO1aNEi/eY3v9E111yj9PR0Pf7440pNTdW0adMi2TcAIMp5DqDdu3e32uqjoKBAkjRr1iytWbNGDz30kBoaGjR37lzV1dXppptu0pYtW3TZZZdFrmuY+te//uWp/sEHHwy79sc//rGnuffs2RN27YoVKzzN7UVFRYWn+ttvv91TfUNDQ9i1I0aM8DT3/fff76keiBTPATRhwgQ55855PiYmRk899ZSeeuqpdjUGAOjazD8FBwC4NBFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMx7nz76hgIBoPy+/3WbcCI16/jqK+vD7v2tdde8zT3nDlzwq796U9/6mnutWvXeqoHolEgEDjvf9PcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABM9rBsAvi0YDHbY3IFAoMPm/sUvfuGpft26dZ7qW1paPNUD0YA7IACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYiHHOOesmvi0YDMrv91u3gS7oiiuu8FS/efPmsGtvueUWT3Pn5uZ6qn///fc91QOdQSAQUHx8/DnPcwcEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMsBUPcA5DhgwJu/bTTz/1NHddXZ2n+g8//DDs2t27d3uae+XKlWHXdrJfF+jk2IoHANApEUAAABOeA2jHjh2aMmWKUlNTFRMTo40bN7Y6P3v2bMXExLQakydPjlS/AIAuwnMANTQ0aPTo0ed93Xjy5Mmqrq4OjbfeeqtdTQIAup4eXn8gNzf3gt9l4vP5lJyc3OamAABdX4e8B1RSUqL+/ftr2LBhmj9/vo4fP37O2qamJgWDwVYDAND1RTyAJk+erDfeeEPFxcX67W9/q9LSUuXm5qq5ufms9UVFRfL7/aGRlpYW6ZYAAJ2Q55fgLuTOO+8M/XvUqFHKyMjQkCFDVFJSookTJ55RX1hYqIKCgtDjYDBICAHAJaDDP4Y9ePBg9e3bVwcOHDjreZ/Pp/j4+FYDAND1dXgAHTlyRMePH1dKSkpHPxUAIIp4fgnuxIkTre5mqqqqVFFR
ocTERCUmJmrp0qXKy8tTcnKyDh48qIceekhDhw5VTk5ORBsHAEQ3z3vBlZSU6NZbbz3j+KxZs7Rq1SpNmzZNe/bsUV1dnVJTUzVp0iT9+te/VlJSUljzsxccotH06dM91a9evdpTfVxcnKd6Lx599NGwa9944w1Pc1dXV3ttB13IhfaC83wHNGHChPNuSLh161avUwIALkHsBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEx43guuo7EXHC4Fo0aN8lT/3HPPhV17tu/dipTXXnvNU/3TTz8ddu0//vEPr+2gk7vQXnDcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABNsxQNEgYSEhLBrp0yZ4mnu1atXh10bExPjae7t27eHXXv77bd7mhudH1vxAAA6JQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYC844BLX1NQUdm2PHj08zf2f//wn7NqcnBxPc5eUlHiqx8XHXnAAgE6JAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY8LavBoCIyMjI8FT/k5/8JOzasWPHeprb6/Y6XnzxxRdh1+7YsaPD+kDnxB0QAMCEpwAqKirS2LFjFRcXp/79+2vatGmqrKxsVdPY2Kj8/Hz16dNHvXv3Vl5enmprayPaNAAg+nkKoNLSUuXn52vnzp364IMPdPr0aU2aNEkNDQ2hmsWLF2vz5s1av369SktLdfToUc2YMSPijQMAopunF3+3bNnS6vGaNWvUv39/lZeXa/z48QoEAvrd736ntWvX6rbbbpMkrV69Wt/73ve0c+dO3XDDDZHrHAAQ1dr1HlAgEJAkJSYmSpLKy8t1+vRpZWdnh2qGDx+ugQMHqqys7KxzNDU1KRgMthoAgK6vzQHU0tKiRYsW6cYbb9TIkSMlSTU1NYqNjVVCQkKr2qSkJNXU1Jx1nqKiIvn9/tBIS0tra0sAgCjS5gDKz8/Xvn37tG7dunY1UFhYqEAgEBqHDx9u13wAgOjQpj8AWLBggd59913t2LFDAwYMCB1PTk7WqVOnVFdX1+ouqLa2VsnJyWedy+fzyefztaUNAEAU83QH5JzTggULtGHDBm3fvl3p6emtzo8ZM0Y9e/ZUcXFx6FhlZaUOHTqkrKysyHQMAOgSPN0B5efna+3atdq0aZPi4uJC7+v4/X716tVLfr9fc+bMUUFBgRITExUfH6+FCxcqKyuLT8ABAFrxFECrVq2SJE2YMKHV8dWrV2v27NmSpBdeeEHdunVTXl6empqalJOTo1deeSUizQIAuo4Y55yzbuLbgsGg/H6/dRuAhg0bFnbtwoULPc09ffp0T/Xneg/1YmtubvZUv23btrBr77jjDq/toJMLBAKKj48/53n2ggMAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACba9HUMQGfhZYuau+++29Pc+fn5YddeffXVnubuTHbv3h127dNPP+1p7j/96U9e28ElhDsgAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgLzh0uKSkpLBrR4wY4Wnul19+Oeza4cOHe5q7M9m1a1fYtc8884ynuTdt2hR2bUtLi6e5gfPhDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgKx4oMTHRU/1rr73mqf76668Pu3bw4MGe5u4sPv74Y0/1zz33nKf6rVu3hl37zTffeJobsMIdEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMsBdclMjMzPRU/+CDD4ZdO27cOE9zX3XVVZ7q
Owuve6S99NJLYdf+3//9n6e5GxoaPNUDXRF3QAAAE54CqKioSGPHjlVcXJz69++vadOmqbKyslXNhAkTFBMT02rMmzcvok0DAKKfpwAqLS1Vfn6+du7cqQ8++ECnT5/WpEmTzng54b777lN1dXVoLF++PKJNAwCin6f3gLZs2dLq8Zo1a9S/f3+Vl5dr/PjxoeOXX365kpOTI9MhAKBLatd7QIFAQNKZX2j25ptvqm/fvho5cqQKCwt18uTJc87R1NSkYDDYagAAur42fwqupaVFixYt0o033qiRI0eGjt99990aNGiQUlNTtXfvXj388MOqrKzUO++8c9Z5ioqKtHTp0ra2AQCIUm0OoPz8fO3bt08fffRRq+Nz584N/XvUqFFKSUnRxIkTdfDgQQ0ZMuSMeQoLC1VQUBB6HAwGlZaW1ta2AABRok0BtGDBAr377rvasWOHBgwYcN7a//39yoEDB84aQD6fTz6fry1tAACimKcAcs5p4cKF2rBhg0pKSpSenn7Bn6moqJAkpaSktKlBAEDX5CmA8vPztXbtWm3atElxcXGqqamRJPn9fvXq1UsHDx7U2rVrdccdd6hPnz7au3evFi9erPHjxysjI6NDLgAAEJ08BdCqVask/fePTb9t9erVmj17tmJjY7Vt2za9+OKLamhoUFpamvLy8vTYY49FrGEAQNfg+SW480lLS1NpaWm7GsLZTZ8+vUPrO9KXX34Zdu3mzZs9zd3c3Bx27bPPPutp7rq6Ok/1ALxhLzgAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGAixl1of52LLBgMyu/3W7cBAGinQCCg+Pj4c57nDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJTwG0atUqZWRkKD4+XvHx8crKytJ7770XOt/Y2Kj8/Hz16dNHvXv3Vl5enmprayPeNAAg+nkKoAEDBmjZsmUqLy/X7t27ddttt2nq1Kn6/PPPJUmLFy/W5s2btX79epWWluro0aOaMWNGhzQOAIhyrp2uvPJK9/rrr7u6ujrXs2dPt379+tC5L7/80klyZWVlYc8XCAScJAaDwWBE+QgEAuf9fd/m94Cam5u1bt06NTQ0KCsrS+Xl5Tp9+rSys7NDNcOHD9fAgQNVVlZ2znmampoUDAZbDQBA1+c5gD777DP17t1bPp9P8+bN04YNG3TdddeppqZGsbGxSkhIaFWflJSkmpqac85XVFQkv98fGmlpaZ4vAgAQfTwH0LBhw1RRUaFdu3Zp/vz5mjVrlr744os2N1BYWKhAIBAahw8fbvNcAIDo0cPrD8TGxmro0KGSpDFjxugvf/mLXnrpJc2cOVOnTp1SXV1dq7ug2tpaJScnn3M+n88nn8/nvXMAQFRr998BtbS0qKmpSWPGjFHPnj1VXFwcOldZWalDhw4pKyurvU8DAOhiPN0BFRYWKjc3VwMHDlR9fb3Wrl2rkpISbd26VX6/X3PmzFFBQYESExMVHx+vhQsXKisrSzfccENH9Q8AiFKeAujYsWP62c9+purqavn9fmVkZGjr1q26/fbbJUkvvPCCunXrpry8PDU1NSknJ0evvPJKhzQOAIhuMc45Z93EtwWDQfn9fus2AADtFAgEFB8ff87z7AUHADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMNHpAqiTbcwAAGijC/0+73QBVF9fb90CACACLvT7vNPtBdfS0qKjR48qLi5OMTExoePBYFBpaWk6fPjwefcWinZc
Z9dxKVyjxHV2NZG4Tuec6uvrlZqaqm7dzn2f4/kL6Tpat27dNGDAgHOej4+P79KL/z9cZ9dxKVyjxHV2Ne29znA2le50L8EBAC4NBBAAwETUBJDP59OSJUvk8/msW+lQXGfXcSlco8R1djUX8zo73YcQAACXhqi5AwIAdC0EEADABAEEADBBAAEATERNAK1cuVJXX321LrvsMmVmZuqTTz6xbiminnzyScXExLQaw4cPt26rXXbs2KEpU6YoNTVVMTEx2rhxY6vzzjk98cQTSklJUa9evZSdna39+/fbNNsOF7rO2bNnn7G2kydPtmm2jYqKijR27FjFxcWpf//+mjZtmiorK1vVNDY2Kj8/X3369FHv3r2Vl5en2tpao47bJpzrnDBhwhnrOW/ePKOO22bVqlXKyMgI/bFpVlaW3nvvvdD5i7WWURFAb7/9tgoKCrRkyRJ9+umnGj16tHJycnTs2DHr1iJqxIgRqq6uDo2PPvrIuqV2aWho0OjRo7Vy5cqznl++fLlWrFihV199Vbt27dIVV1yhnJwcNTY2XuRO2+dC1ylJkydPbrW2b7311kXssP1KS0uVn5+vnTt36oMPPtDp06c1adIkNTQ0hGoWL16szZs3a/369SotLdXRo0c1Y8YMw669C+c6Jem+++5rtZ7Lly836rhtBgwYoGXLlqm8vFy7d+/WbbfdpqlTp+rzzz+XdBHX0kWBcePGufz8/NDj5uZml5qa6oqKigy7iqwlS5a40aNHW7fRYSS5DRs2hB63tLS45ORk98wzz4SO1dXVOZ/P59566y2DDiPju9fpnHOzZs1yU6dONemnoxw7dsxJcqWlpc65/65dz5493fr160M1X375pZPkysrKrNpst+9ep3PO3XLLLe7++++3a6qDXHnlle7111+/qGvZ6e+ATp06pfLycmVnZ4eOdevWTdnZ2SorKzPsLPL279+v1NRUDR48WPfcc48OHTpk3VKHqaqqUk1NTat19fv9yszM7HLrKkklJSXq37+/hg0bpvnz5+v48ePWLbVLIBCQJCUmJkqSysvLdfr06VbrOXz4cA0cODCq1/O71/k/b775pvr27auRI0eqsLBQJ0+etGgvIpqbm7Vu3To1NDQoKyvroq5lp9uM9Lu+/vprNTc3KykpqdXxpKQk/e1vfzPqKvIyMzO1Zs0aDRs2TNXV1Vq6dKluvvlm7du3T3FxcdbtRVxNTY0knXVd/3euq5g8ebJmzJih9PR0HTx4UI8++qhyc3NVVlam7t27W7fnWUtLixYtWqQbb7xRI0eOlPTf9YyNjVVCQkKr2mhez7NdpyTdfffdGjRokFJTU7V37149/PDDqqys1DvvvGPYrXefffaZsrKy1NjYqN69e2vDhg267rrrVFFRcdHWstMH0KUiNzc39O+MjAxlZmZq0KBB+sMf/qA5c+YYdob2uvPOO0P/HjVqlDIyMjRkyBCVlJRo4sSJhp21TX5+vvbt2xf171FeyLmuc+7cuaF/jxo1SikpKZo4caIOHjyoIUOGXOw222zYsGGqqKhQIBDQH//4R82aNUulpaUXtYdO/xJc37591b179zM+gVFbW6vk5GSjrjpeQkKCrr32Wh04cMC6lQ7xv7W71NZVkgYPHqy+fftG5douWLBA7777rj788MNWX5uSnJysU6dOqa6urlV9tK7nua7zbDIzMyUp6tYzNjZWQ4cO1ZgxY1RUVKTRo0frpZdeuqhr2ekDKDY2VmPGjFFxcXHoWEtLi4qLi5WVlWXYWcc6ceKEDh48qJSUFOtWOkR6erqSk5NbrWswGNSuXbu69LpK0pEjR3T8+PGoWlvnnBYsWKANGzZo+/btSk9Pb3V+zJgx6tmzZ6v1rKys1KFDh6JqPS90nWdTUVEhSVG1nmfT0tKipqami7uWEf1IQwdZt26d8/l8bs2aNe6LL75wc+fOdQkJCa6mpsa6tYj51a9+5UpKSlxVVZX785//7LKzs13fvn3d
sWPHrFtrs/r6erdnzx63Z88eJ8k9//zzbs+ePe6rr75yzjm3bNkyl5CQ4DZt2uT27t3rpk6d6tLT090333xj3Lk357vO+vp698ADD7iysjJXVVXltm3b5n7wgx+4a665xjU2Nlq3Hrb58+c7v9/vSkpKXHV1dWicPHkyVDNv3jw3cOBAt337drd7926XlZXlsrKyDLv27kLXeeDAAffUU0+53bt3u6qqKrdp0yY3ePBgN378eOPOvXnkkUdcaWmpq6qqcnv37nWPPPKIi4mJce+//75z7uKtZVQEkHPOvfzyy27gwIEuNjbWjRs3zu3cudO6pYiaOXOmS0lJcbGxse6qq65yM2fOdAcOHLBuq10+/PBDJ+mMMWvWLOfcfz+K/fjjj7ukpCTn8/ncxIkTXWVlpW3TbXC+6zx58qSbNGmS69evn+vZs6cbNGiQu++++6Luf57Odn2S3OrVq0M133zzjfvlL3/prrzySnf55Ze76dOnu+rqarum2+BC13no0CE3fvx4l5iY6Hw+nxs6dKh78MEHXSAQsG3co3vvvdcNGjTIxcbGun79+rmJEyeGwse5i7eWfB0DAMBEp38PCADQNRFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDx/wAzgwe7F079ZQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# visualize first dataset\n",
    "x = next(iter(dataloader))[0][:1]\n",
    "x_as_image = make_grid(x.float(), nrow=1)\n",
    "plt.figure()\n",
    "plt.imshow(x_as_image.permute(1, 2, 0).cpu().numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABkQAAAMtCAYAAADOtR3+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAq9ElEQVR4nO3df2yV9dn48avABGQtGxILlSpuyzYbYxuhEIZbRPtI1JBhtojbMhA3l23VSaoukiWwZSz+yGKI46jbMsU5psRlMLMsNYwtMA2TCmFh63RTWazD8kO3FjpFpef5w9jnywSl/bY9nIvXKzl/nPuc+9zX/Qcf2rx7n7uiWCwWAwAAAAAAILERpR4AAAAAAABgqAkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJDeqFIP0F+9vb2xe/fuqKysjIqKilKPAwAAAAAAlFCxWIwDBw5ETU1NjBhx7OtAyi6I7N69O2pra0s9BgAAAAAAcALp6OiIKVOmHPP1sgsilZWVEfHWiVVVVZV4GgAAAAAAoJS6u7ujtra2rx8cS9kFkbe/JquqqkoQAQAAAAAAIiLe8zYbbqoOAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6wx5EOjo64sILL4y6uro477zz4pFHHhnuEQAAAAAAgJPMqGE/4KhRsXLlymhoaIjOzs6YNm1aXHbZZTFu3LjhHgUAAAAAADhJDHsQmTx5ckyePDkiIiZNmhQTJ06MV155RRABAAAAAACGTL+/Mmvz5s0xb968qKmpiYqKili/fv073lMoFGLq1KkxZsyYmDlzZmzduvWon7Vt27Y4fPhw1NbW9ntwAAAAAACA49XvINLT0xP19fVRKBSO+vratWujpaUlli9fHtu3b4/6+vqYO3du7N2794j3vfLKK7Fw4cL40Y9+9K7HO3ToUHR3dx/xAAAAAAAA6I+KYrFYHPDOFRWxbt26mD9/ft+2mTNnRmNjY6xatSoiInp7e6O2tjauv/76uOWWWyLircjxP//zP3HttdfGF7/4xXc9xre//e34zne+847tXV1dUVVVNdDRAQCABF685Q8D2m/KbZ8c5EkAAIBS6e7ujvHjx79nN+j3FSLv5vXXX49t27ZFU1PT/x1gxIhoamqKLVu2REREsViMq6++Oi666KL3jCEREUuXLo2urq6+R0dHx2CODAAAAAAAnAQGNYjs378/Dh8+HNXV1Udsr66ujs7OzoiIeOKJJ2Lt2rWxfv36aGhoiIaGhti5c+cxP3P06NFRVVV1xAMAAAAAAKA/Rg33AS+44ILo7e0d7sMCAAAAAAAnsUG9QmTixIkxcuTI2LNnzxHb9+zZE5MmTRrMQwEAAAAAABy3QQ0ip5xySkybNi02btzYt623tzc2btwYs2bNGsxDAQAAAAAAHLd+f2XWwYMH49lnn+17vmvXrtixY0dMmDAhzjzzzGhpaYlFixbF9OnTY8aMGbFy5cro6emJxYsXD+rgAAAAAAAAx6vfQeSpp56KOXPm9D1vaWmJiIhFixbF6tWrY8GCBbFv375YtmxZdHZ2RkNDQ7S2tr7jRusAAAAAAADDpaJYLBZLPUR/dHd3x/jx46OrqyuqqqpKPQ4AAFBCL97yhwHtN+W2Tw7yJAAAQKkcbzcY1HuIAAAAAAAAnIgE
EQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEivbIJIoVCIurq6aGxsLPUoAAAAAABAmSmbINLc3Bzt7e3R1tZW6lEAAAAAAIAyUzZBBAAAAAAAYKAEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACC9sgkihUIh6urqorGxsdSjAAAAAAAAZaZsgkhzc3O0t7dHW1tbqUcBAAAAAADKTNkEEQAAAAAAgIESRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAIL2yCSKFQiHq6uqisbGx1KMAAAAAAABlpmyCSHNzc7S3t0dbW1upRwEAAAAAAMpM2QQRAAAAAACAgRJEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANIrmyBSKBSirq4uGhsbSz0KAAAAAABQZsomiDQ3N0d7e3u0tbWVehQAAAAAAKDMlE0QAQAAAAAAGChBBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAA
AEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvbIJIoVCIerq6qKxsbHUowAAAAAAAGWmbIJIc3NztLe3R1tbW6lHAQAAAAAAykzZBBEAAAAAAICBEkQAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPTKJogUCoWoq6uLxsbGUo8CAAAAAACUmbIJIs3NzdHe3h5tbW2lHgUAAAAAACgzZRNEAAAAAAAABkoQAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQR
AAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASK9sgkihUIi6urpobGws9SgAAAAAAECZKZsg0tzcHO3t7dHW1lbqUQAAAAAAgDJTNkEEAAAAAABgoAQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvbIJIoVCIerq6qKxsbHUowAAAAAAAGWmbIJIc3NztLe3R1tbW6lHAQAAAAAAykzZBBEAAAAAAICBEkQAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAA
ANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8Q
AQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvZIEkSuuuCI++MEPxmc/+9lSHB4AAAAAADjJlCSI3HDDDfHTn/60FIcGAAAAAABOQiUJIhdeeGFUVlaW4tAAAAAAAMBJqN9BZPPmzTFv3ryoqamJioqKWL9+/TveUygUYurUqTFmzJiYOXNmbN26dTBmBQAAAAAAGJB+B5Genp6or6+PQqFw1NfXrl0bLS0tsXz58ti+fXvU19fH3LlzY+/evQMa8NChQ9Hd3X3EAwAAAAAAoD/6HUQuvfTSWLFiRVxxxRVHff3OO++Ma6+9NhYvXhx1dXVx7733xqmnnhr33XffgAa89dZbY/z48X2P2traAX0OAAAAAABw8hrUe4i8/vrrsW3btmhqavq/A4wYEU1NTbFly5YBfebSpUujq6ur79HR0TFY4wIAAAAAACeJUYP5Yfv374/Dhw9HdXX1Edurq6vj6aef7nve1NQUf/rTn6KnpyemTJkSjzzySMyaNeuonzl69OgYPXr0YI4JAAAAAACcZAY1iByv3/72t6U4LAAAAAAAcJIa1K/MmjhxYowcOTL27NlzxPY9e/bEpEmTBvNQAAAAAAAAx21Qg8gpp5wS06ZNi40bN/Zt6+3tjY0bNx7zK7EAAAAAAACGWr+/MuvgwYPx7LPP9j3ftWtX7NixIyZMmBBnnnlmtLS0xKJFi2L69OkxY8aMWLlyZfT09MTixYsHdXAAAAAAAIDj1e8g8tRTT8WcOXP6nre0tERExKJFi2L16tWxYMGC2LdvXyxbtiw6OzujoaEhWltb33GjdQAAAAAAgOFSUSwWi6Ueoj+6u7tj/Pjx0dXVFVVVVaUeBwAAKKEXb/nDgPabctsnB3kSAACgVI63GwzqPUQAAAAAAABORIIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApFc2QaRQKERdXV00NjaWehQAAAAAAKDMlE0QaW5ujvb29mhrayv1KAAAAAAAQJkZVeoB+qtYLEZERHd3d4knAQAASu3AoZ4B7ef3CQAAyOPtn+/f7gfHUnZB5MCBAxERUVtbW+JJAACAsrWy1AMAAACD7cCBAzF+/Phjvl5RfK9kcoLp7e2N3bt3R2VlZVRUVJR6HChb3d3dUVtbGx0dHVFVVVXqcYCErDPAULPOAEPNOgMMJWsMDJ5isRgHDhyImpqaGDHi2HcKKbsrREaMGBFTpkwp9RiQRlVVlf90gSFlnQGGmnUGGGrWGWAoWWNgcLzblSFvK5ubqgMAAAAAAAyUIAIAAAAAAKQniMBJavTo0bF8+fIYPXp0qUcBkrLOAEPNOgMMNesMMJSsMTD8yu6m6gAAAAAAAP3lChEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBJI6cOBALFmyJM4666wYO3ZsfOITn4i2trZ33efQoUPxrW99K84666wYPXp0TJ06Ne67775hmhgoNwNZZ9as
WRP19fVx6qmnxuTJk+Oaa66Jl19+eZgmBk5kmzdvjnnz5kVNTU1UVFTE+vXrj3i9WCzGsmXLYvLkyTF27NhoamqKv//97+/5uYVCIaZOnRpjxoyJmTNnxtatW4foDIAT3VCsM7feems0NjZGZWVlnH766TF//vx45plnhvAsgBPZUP0887bbbrstKioqYsmSJYM7OJxEBBFI6stf/nJs2LAhHnzwwdi5c2dccskl0dTUFP/85z+Puc+VV14ZGzdujJ/85CfxzDPPxEMPPRQf+9jHhnFqoJz0d5154oknYuHChfGlL30p/vKXv8QjjzwSW7dujWuvvXaYJwdORD09PVFfXx+FQuGor99xxx1x1113xb333htPPvlkjBs3LubOnRuvvfbaMT9z7dq10dLSEsuXL4/t27dHfX19zJ07N/bu3TtUpwGcwIZindm0aVM0NzfHH//4x9iwYUO88cYbcckll0RPT89QnQZwAhuKdeZtbW1t8cMf/jDOO++8wR4bTioVxWKxWOohgMH16quvRmVlZfzqV7+Kyy+/vG/7tGnT4tJLL40VK1a8Y5/W1ta46qqr4vnnn48JEyYM57hAGRrIOvP9738/7rnnnnjuuef6tv3gBz+I22+/PV588cVhmRsoDxUVFbFu3bqYP39+RLz115Q1NTVx4403xk033RQREV1dXVFdXR2rV6+Oq6666qifM3PmzGhsbIxVq1ZFRERvb2/U1tbG9ddfH7fccsuwnAtwYhqsdea/7du3L04//fTYtGlTfOpTnxqq8YEyMJjrzMGDB+P888+Pu+++O1asWBENDQ2xcuXKYTgLyMcVIpDQm2++GYcPH44xY8YcsX3s2LHx+OOPH3WfRx99NKZPnx533HFHnHHGGfHRj340brrppnj11VeHY2SgzAxknZk1a1Z0dHTEb37zmygWi7Fnz574xS9+EZdddtlwjAyUsV27dkVnZ2c0NTX1bRs/fnzMnDkztmzZctR9Xn/99di2bdsR+4wYMSKampqOuQ9w8hrIOnM0XV1dERH+yAx4h/+fdaa5uTkuv/zyI/YFBmZUqQcABl9lZWXMmjUrvvvd78Y555wT1dXV8dBDD8WWLVviIx/5yFH3ef755+Pxxx+PMWPGxLp162L//v3x9a9/PV5++eW4//77h/kMgBPdQNaZ2bNnx5o1a2LBggXx2muvxZtvvhnz5s075uXkAG/r7OyMiIjq6uojtldXV/e99t/2798fhw8fPuo+Tz/99NAMCpStgawz/623tzeWLFkSs2fPjnPPPXfQZwTK20DXmYcffji2b9/+nvdrBI6PK0QgqQcffDCKxWKcccYZMXr06Ljrrrvic5/7XIwYcfR/9r29vVFRURFr1qyJGTNmxGWXXRZ33nlnPPDAA64SAY6qv+tMe3t73HDDDbFs2bLYtm1btLa2xj/+8Y/46le/OsyTAwAMvubm5vjzn/8cDz/8cKlHAZLo6OiIG264IdasWfOOq/OBgRFEIKkPf/jDsWnTpjh48GB0dHTE1q1b44033ogPfehDR33/5MmT44wzzojx48f3bTvnnHOiWCz6bn/gqPq7ztx6660xe/bsuPnmm+O8886LuXPnxt133x333XdfvPTSS8M8PVBOJk2aFBERe/bsOWL7nj17+l77bxMnToyRI0f2ax/g5DWQdeb/dd1118Wvf/3r+P3vfx9TpkwZkhmB8jaQdWbbtm2xd+/eOP/882PUqFExatSo2LRpU9x1110xatSoOHz48JDPDdkIIpDcuHHjYvLkyfGvf/0rHnvssfj0pz991PfNnj07du/eHQcPHuzb9re//S1GjBjhB3rgXR3vOvOf//znHVePjBw5MiLeusEgwLGcffbZMWnSpNi4cWPftu7u7njyySdj1qxZR93nlFNOiWnTph2xT29vb2zcuPGY+wAnr4GsMxFv/Qxz3XXXxbp16+J3v/tdnH322cMxLlCGBrLOXHzxxbFz587YsWNH32P6
9OnxhS98IXbs2NH3+xRw/AQRSOqxxx6L1tbW2LVrV2zYsCHmzJkTH//4x2Px4sUREbF06dJYuHBh3/s///nPx2mnnRaLFy+O9vb22Lx5c9x8881xzTXXxNixY0t1GsAJrL/rzLx58+KXv/xl3HPPPfH888/HE088Ed/4xjdixowZUVNTU6rTAE4QBw8e7PtFP+KtG4/u2LEjXnjhhaioqIglS5bEihUr4tFHH42dO3fGwoULo6amJubPn9/3GRdffHGsWrWq73lLS0v8+Mc/jgceeCD++te/xte+9rXo6enpW6eAk8tQrDPNzc3xs5/9LH7+859HZWVldHZ2Rmdnp68dhpPUYK8zlZWVce655x7xGDduXJx22mnuVQQD5KbqkFRXV1csXbo0XnzxxZgwYUJ85jOfie9973vxvve9LyIiXnrppXjhhRf63v/+978/NmzYENdff31Mnz49TjvttLjyyitjxYoVpToF4ATX33Xm6quvjgMHDsSqVavixhtvjA984ANx0UUXxe23316qUwBOIE899VTMmTOn73lLS0tERCxatChWr14d3/zmN6Onpye+8pWvxL///e+44IILorW19Yjv037uuedi//79fc8XLFgQ+/bti2XLlkVnZ2c0NDREa2vrO25mCpwchmKdueeeeyIi4sILLzziWPfff39cffXVQ3cywAlpKNYZYHBVFH1HBQAAAAAAkJyvzAIAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACC9/wUt71o0Sbl9QgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 2000x1000 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# check parameter distribution\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "def plot_param_dist(model):\n",
    "    plot_dist = {}\n",
    "    for name, param in model.named_parameters():\n",
    "        if param.requires_grad:\n",
    "            plot_dist[name] = param.detach().cpu().numpy().flatten()\n",
    "            # sample 1000\n",
    "            # replace nan with 10\n",
    "            plot_dist[name] = np.where(np.isnan(plot_dist[name]), 10, plot_dist[name])\n",
    "            plot_dist[name] = np.random.choice(plot_dist[name], 1000)\n",
    "    \n",
    "    plt.figure(figsize=(20, 10))\n",
    "\n",
    "    for idx, (name, dist) in enumerate(plot_dist.items()):\n",
    "    \n",
    "        plt.hist(dist, bins=100, density=True)\n",
    "        \n",
    "    # log y\n",
    "    plt.yscale('log')\n",
    "\n",
    "    #plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_param_dist(d3pm.x0_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f72c4434fe0>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiYAAACxCAYAAADwMnaUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkhklEQVR4nO3dfWxV5R0H8G8LtKDQy4uhtSuFbmNDJ1YHWjvMcNrJiFEYZFPDRlESoyuO0mRTtqHGzVVdNl58gbktwDIZjkRgkKhhRUpMylsRJuIqy5igeMucay+ilI4++2Pzjvvr8f7O755T7mn5fpKb9Nzz9pznvPTJeX739+Q45xyIiIiIIiA32wUgIiIi+hgbJkRERBQZbJgQERFRZLBhQkRERJHBhgkRERFFBhsmREREFBlsmBAREVFksGFCREREkcGGCREREUUGGyZEREQUGT3WMHnqqacwZswYDBw4EBUVFdi1a1dP7YqIiIj6iJyeGCvnueeew+zZs7FixQpUVFRgyZIlWLduHVpaWjBy5Mi063Z1deHYsWMYMmQIcnJywi4aERER9QDnHE6cOIHi4mLk5mb+3qNHGiYVFRW46qqr8OSTTwL4b2Nj1KhRuPfee3H//fenXfftt9/GqFGjwi4SERERnQNHjx5FSUlJxuuH3pVz+vRpNDc3o6qq6v87yc1FVVUVmpqaui3f0dGBRCKR/HCwYyIiot5ryJAhgdYPvWHy3nvv4cyZMygsLEz5vrCwEPF4vNvy9fX1iMViyU9paWnYRSIiIqJzJGgYRtZ/lbNw4UK0t7cnP0ePHs12kYiIiChL+oe9wYsuugj9+vVDa2tryvetra0oKirqtnx+fj7y8/PDLgYRERH1QqG/McnLy8OECRPQ0NCQ/K6rqwsNDQ2orKwMe3dERETUh4T+xgQA6urqUF1djYkTJ+Lqq6/GkiVLcPLkSdxxxx09sTsiIiLqI3qkYXLrrbfiH//4Bx544AHE43FcccUVePHFF7sFxGYqFoulTI8ZMyZluq2tLe20XN7L3//+95TpoUOHpl1e7kNb3lpmKezy+NmnRtuHnN/Y2Jh2+dGjR6dMa8dkPWY/x+un3oIIet3I61TS6shrGeu1L+dfccUVKdOrV69Ou355eXnKdNDrMAzaedGm9+3bl3a+9VqWMnl+aNeOVobrrrsu7fr79+/vXtCzyPvZWj5NJs+8ntiGZX3rvZYJ6zFp18Fbb70VvFCKHmmYAMC8efMwb968nto8ERER9UFZ/1UOERER0cfYMCEiIqLI6LGunJ40Z86clGnZByb7dyWtf9iL1hdonQ47psS6Pz/kNoLGN1hp/eja/uT6Wr+/F61/dvr06SnT27ZtU7eZbvtyexs2bEi7vGStA6D7eZSxBNryWoyKRt7Psoxye2Fcd1o9yjgZLYZEK4Oclscsz7MWB5dJbELQY5bXhbzWtRgTuX1ZHnlMPfG8sT77ZZnledfqVIvTsZ4Trc68aM9ROV/uU57ncxFjwjcmREREFBlsmBAREVFksGFCREREkdErY0y0/l2t3172K3rFBch+N61vUOv7C9qXKGnHILcnj8erb1UrszV3gpyWZdLymGjry+tAq0NZfq9+fFmPchl5LVn7+rXzoPWzy/1L2nXgtf6SJUvSlkETNPeDLLM1ZkWLffCKmdHi0mS9yTJo16IW46U9w7T9yxgV7br1KoPcpiyTtg9rPJVWHu260eK5tOeF1z7ktbFq1aq0+9TOq/UZqj2ztPX93CvyGLU4GTlf7lOLJQoD35gQERFRZLBhQkRERJHBhgkRERFFRo5zzmW7EGdLJBLdxsKRtLE1tH45P3EB1rEqgo55YM0zEnTsjUyO2Vomrf+zvb097fas51krTyaxENZxZIKOgWLVE/vX1tHiK6x90pMnT06Ztsa4+MnVYt2G9Txr2wt63q3PND9lsj6ztBw7Wn4L65hI1ntP8tq+fObI/zXW82yt
MxnDEjTXU9jPE0DPoeMnxqS9vR0FBQUZl4FvTIiIiCgy2DAhIiKiyGDDhIiIiCKDDRMiIiKKjF4Z/Dp69Oi0860BTF7zMwkus5TBGrgVNDivJwb5ClonWhCVdp6tMgkA1urVGkyn7U8KGkTdE8Gv2vKSFhQp6yzovXcuaMccdtC0dft+BA28lLTzrAU5B73O/NShFritCZpMMOwfGGj/J/ywlsnPIH4MfiUiIqI+gw0TIiIiigw2TIiIiCgyeuUgfkEH0Mok3kJLmGTtE9b69eTAS9bkQpnEg1gHqLKy9mVqsQbWAa606wTQ612yDsKnzdeSyFmPORPW5GCyzmTiLY024KRkvTe9yt/TMWDatWaNo9GSF2aS0E0bQE4egzYYnMYaw5bJ/avRyqwdozXGI2hskfX/ip9ltGvtoYceSjt9LvCNCREREUUGGyZEREQUGWyYEBERUWT0yhgTrd9P60OTfZlefeLWmJJMcqWk2582MJp2jJnkBLDGL1jjH4LmPQmaE8RPPMa2bdvSLmOtV61PO+ycIXJ9beA1r21a+8XlMdbW1qZML1iwoHtB05DnQNLuRT/5MayxBFq8hSyzNZ+E9RxosUpe14k8hqDxU3K+lt9CLi+vzSVLlqRd3hpL5HUOtPtBqyM5CJ8sc9DYQWvMiZ94Kuv9Io9J1tnSpUu77SNsfGNCREREkcGGCREREUUGGyZEREQUGb0yxiRoX6PsV/Tqj7X27cnfest+di1nQCZ9h5b1pUzGbLDGtQSNMdHqwBpbJNeX/b1erH35WmyQFp9hFcb4IUHH75Dryz5qjTX3i7a+n9wOWhyLJLdpjYPR5gcdcyWT3C1Bc+gELbP1XrHmVfKqA7lP6zHJ57q1DuR1EzR2MZOcXNo2tPipc4FvTIiIiCgy2DAhIiKiyDA3TLZv346bb74ZxcXFyMnJ6dYt4pzDAw88gIsvvhiDBg1CVVUVDh06FFZ5iYiIqA8zx5icPHkS5eXluPPOOzFjxoxu8x9//HEsW7YMq1evRllZGRYtWoQpU6bg4MGDGDhwYCiFlqz5N2Qfmp/ffmu5T6z96tr2wx6Xxs/2tWWs/alB4yes/a+y/JI1JsVrm/KYtPE9tDLIHAEyT4I1H0YmcQHaMvIYrf3kWn4La04Q7bqU25N16rVNydovr+UhCjoGU9AcQ372Yc1fE5R2DNp1Zr03vLah3Y9yeeu1r+VNseY10Xgtrz2j5D5lGTOp56DMDZOpU6di6tSpnvOcc1iyZAl+9KMfYdq0aQCA3/72tygsLMSGDRtw2223BSstERER9WmhxpgcPnwY8XgcVVVVye9isRgqKirQ1NTkuU5HRwcSiUTKh4iIiM5PoTZM4vE4AKCwsDDl+8LCwuQ8qb6+HrFYLPkZNWpUmEUiIiKiXiTreUwWLlyIurq65HQikVAbJ9p4IFo/YCaxE1q/e9Df+Gu/FbfGMmjHOGfOnG7r+Mnvko52XuT2tdgDaz+/dsxaXABgH9tGy8EhyeXlebfGGmnHqO3fixZTYs3xo7HmBLHWsR/WZ4Q1BkSL/dFimYLmOfIqk1bvYZ9nub4WyyC3L3OIaM95P2XQ6lEbO0c+4+R8OS1jSuR5l/mwrM9gP+MDWfcRdmyRH6G+MSkqKgIAtLa2pnzf2tqanCfl5+ejoKAg5UNERETnp1AbJmVlZSgqKkJDQ0Pyu0QigZ07d6KysjLMXREREVEfZO7K+eCDD/DXv/41OX348GHs27cPw4cPR2lpKWpra/GTn/wEY8eOTf5cuLi4uNvrJCIiIiIpxznnLCts27YNX/nKV7p9X11djVWrVsE5hwcffBDPPPMM2tracO211+Lpp5/G5z73OV/bTyQSiMViaZcpLy9PO1+L99DGN/FaRusD1n6Tbx3XIWgsg9Yv6LW+dcwSrV61etZiTOR1YK1j2Z/rJ6+KNUZDy19h7aeX/ehye9Z+eW2+13daLgfreCP79+/vts+zjR49Ou36Vpn0iVvj
WHo6/kkrj/b88RLGuCvptqfdz/K5HTRuRnseeLHGgFhjzqx1rD0jrXFyXnFDQf83Sdp5BoD29vZAYRnmNybXXXcd0rVlcnJy8PDDD+Phhx/OuFBERER0fuJYOURERBQZbJgQERFRZGQ9j0kmtLEwtH4/LQeJ13dB+xq12ARrnhKZh8RrPJB02/PavhZTosUryOXlMcn1tb7KoPETWp4Gr355a1yNnxiOs2ljY8gYEut4QVr5vVjzlmi5Vqx91kHzZWSSh0E7b0H7/q3PHGtMidyfjE3yGrtLiyWS503uU+5De+ZotPPq5zl9Nj/Xodym9dq33u9S2HEzfnJyaceg5WaRdeInxiQovjEhIiKiyGDDhIiIiCKDDRMiIiKKDHMek57mJ4/J5MmTU6a1/tqgeRG8tqH17Wl9zHL9oONOBI1F8CpT0HgGrb9048aN3cpwNpnfImjOkEzy1wQdH0T2y8txKrQ6sl7bYVz71tgi7brQ8phoeYkkuT8Zt6PFFvmh5QWxxpRpeZA01txMfuLmZJm05bX73Xqegz5TJa/xvyQtHlHSxvsKOzeM9f+IFhPjZx/WODrtuQ0Ez2PCNyZEREQUGWyYEBERUWSwYUJERESR0SvzmFj7PjPJYxI03sLaLy+F/Xv5TMZH0MaR0Oo5kzFLzqbFX2jT8jzL/uFMxtaw9hlrY9v4iXuxkHXuJ79FJnlAzib79uXyWuyBNWZF8jNOjCSvZUnru7fm1LHmdtFiCSQ/97e2jPZcDCNWLx3tOtByPVnjR7yWkfeflqvFep6sOYCsz1Q/512brz03zwW+MSEiIqLIYMOEiIiIIoMNEyIiIoqMPpHHxJozwM+4MVouB2sMirWfz9onbc1v4acM1t+7W2ljLsjzLM+B7M+15nrwquNMxpqx7FMTNJdDJtehNcbEei03NjamXV7mt9DuPY2f8gbdh7bPnjhP6eZnksfEmjNHi4+w3s/atW6NgfFDW8caO2SNh9KeJ1oOEW1/XjEuWgyYFnsn47FWr17dbR8S85gQERFRn8GGCREREUUGGyZEREQUGWyYEBERUWT0iQRrkhbsqi0P2INdwx4AT7JuT/KzfNiBljLoSgZZacFyUtA6kgPoeSUbswbcWROaWQci1FivUy/aebYm3rImidMC/iRroKjX8QUNZtcS5WnHYE1iJfevDS7n57xr16KkBQxr97P1PMnty+eJvLf8BCAHfY5atydZk1Rq/ATny4BYazLQbOAbEyIiIooMNkyIiIgoMtgwISIiosjolTEmWgI1SYsb8BNvofXHav10WkxK2P16YSQfsiZkCjshmzWmRBuYTfZJ90SCNTnoV9CkUNY6yyTpnZZIy5poS5ZBiz2w1okWF+BnYLWg50Hbh5b4SsaIWGNerDEpXtuwxr1pz0QtkZ58ZmrPcbl/GSMmZZJkzjpfqyPt/4QWUybrRLI+EwE9Nkd7bgYdWDQTfGNCREREkcGGCREREUUGGyZEREQUGb0yxkT2A1r73bT4ED/LaPEYWg4PrYzafOvv5eX+/fS/hv37dmu8hjWfhYzvsPbfem1Tm6/F1Vj7pK05RbQ4AHkdyDryWkfS+v6D0gak0+5FWR5573nFmGi0e0Hrd5f7lMck52sxZ1reFMmrfFo9a/UUNF+Ndu1q86287ndr3Jyk3V9anWjxW3J9a3m8yJgS7bzL+Kds4BsTIiIiigxTw6S+vh5XXXUVhgwZgpEjR2L69OloaWlJWebUqVOoqanBiBEjMHjwYMycOROtra2hFpqIiIj6JlPDpLGxETU1NdixYwe2bNmCzs5O3HjjjTh58mRymQULFmDTpk1Yt24dGhsbcezYMcyYMSP0ghMREVHfY4oxefHFF1OmV61ahZEjR6K5uRlf/vKX0d7ejt/85jdYs2YNrr/+egDAypUrcckll2DHjh245pprQim01gem9e9qeRkAPQ+BZO2vteZBkGWWfZtyvlYHfspkzWOQyT7TseaCkLQ68WLt
B9fiH6x9xtr+tetEkrFFXte63IfWbx52LJIWM2bt19fGJ/KzD4113Bc/8U3pti+vdWt+DK8yafvUjlHW6/79+9MuL4/Ba6yqs1njbPzEHmr3lzUGRYuD0Z4Hsk60543kJw4naPxiNsbOCRRj0t7eDgAYPnw4AKC5uRmdnZ2oqqpKLjNu3DiUlpaiqakpyK6IiIjoPJDxr3K6urpQW1uLSZMm4bLLLgMAxONx5OXldWuBFRYWIh6Pe26no6MDHR0dyelEIpFpkYiIiKiXy/iNSU1NDQ4cOIC1a9cGKkB9fT1isVjyM2rUqEDbIyIiot4rozcm8+bNw+bNm7F9+3aUlJQkvy8qKsLp06fR1taW8taktbUVRUVFnttauHAh6urqktOJREJtnMhcDFp/sdZvmMlv/q19lVq8hBa7EHYMi59YA0n7zb0W7xCUVofaeda2B+j13tOxCdY4He2YMxk/xJpzRyujNlaOFqMi61yeIy1vide9Yq1Xa34b7f7TYsQkGY+h5cPwc11q+9Rij6x5RrSYEuszNZNj1u5f6/2tra8917V7TYtJ0ZbPRNh5ijJhemPinMO8efOwfv16bN26FWVlZSnzJ0yYgAEDBqChoSH5XUtLC44cOYLKykrPbebn56OgoCDlQ0REROcn0xuTmpoarFmzBhs3bsSQIUOScSOxWAyDBg1CLBbD3LlzUVdXh+HDh6OgoAD33nsvKisrQ/tFDhEREfVdpobJ8uXLAXR/pbdy5crk69TFixcjNzcXM2fOREdHB6ZMmYKnn346lMISERFR35bjnHPZLsTZEokEYrFY2mUmT56cdr61P9gP2Q8u+w5l3IvWby77xbXxDLSYFet4QT3BGtfS2NiYdnvl5eUp09Y8DJnkFLGOz6PFFgSNVbCO5aHFW3itb43JsuYZWb16dfeCnmX06NFp59fW1qZMa7EKkp94qqBjHmn7tMYOWc9JJrRj1JaX19rSpUvTri+f29Zxo4LeG5msI49RG5/Legza8tq9pcWoAPZrUcuZoz23gf+mEgkSlsGxcoiIiCgy2DAhIiKiyGDDhIiIiCKjV8aYaH3SWj+i7LN+6KGH1G0EzXtgzS+RSf9pOn5yjFh/v671l2rLa/ktpk2blnZ9ydpnHYag8RjWcSqs141WHi9aPIW1DNbzbB3rKox4DO3asF5LQc+LNi6M5GesHCt5DFoODS2WSMaMaax1lMkYL0Gfs9b1g+7Pem/6WcYa96LdzwBjTIiIiKgPYcOEiIiIIoMNEyIiIoqMXhljYs1vkUn8hpaPImgeEbm+/D160DFYrH3mfrahjV1hzc2wf//+tMtrsUQa6zkC7LFEQVn7dyVr7gmv47PePzKXghbfEPQ8B43D8UNbJ2i8kjV+S8tXoV3bXrlb5HnSxnHRxvOR+9RiD7TzHDQux+uYNVrcilZnQePygsZHhZHLRdLKoN3PAGNMiIiIqA9hw4SIiIgigw0TIiIiigzTIH5RZe2b9NMHrfUdyv5XmXtB63e3jsWhjdUj96fxOj5rPIP19/BWQXPHyDrR8jB4rRO0TzhoXhJr7pZMWK9N67UWdP8ard/fqw6t1651rBvJet79jIFiFTR/hbY9P/kt0gk6Vo7kdR3JfVjHqtGEXceaMOLmrGX2E2MSFN+YEBERUWSwYUJERESRwYYJERERRUavjDHR4jvmzJmTMi37a2UfWia/+V+yZEnaMkpamYPGHli3nwlrn6+ctsYmhDHuS7r9+6kT6/g/Whm1/l5rngRrnfjpd5e0Ywwag2KNTbJey173t3XcFXl/aXEsQeMl5DNs1apVSMfPvWaNtdOek/IZo7HGd8hjkHUi5/t5HmjjpMk6keddlln7PxA0P441T4mfnFzac1D7X3Iu8I0JERERRQYbJkRERBQZbJgQERFRZLBhQkRERJHRKwfxmzZtWsq0FmBoDUD0WiaToKOzyeReWsImLZDLmgwtkzrQtqmVSQsIbGxsTLu9xYsXp0zL
AMCwk1x5LaPRAj+DDgqYSRKpdOv7qbOgwahyWku8NX/+/JRpeW9o+9cC072OWX6nBXYGHfgsaCI9beBErTxetGeENeDXOlij3L8MbtUCU7Vz5nXdWIOeteemlnRS27/1eWNNlua1D2ugttym9twGOIgfERER9SFsmBAREVFksGFCREREkdErE6xZE/9o/b1+BpvTkv8ETS4m+x7lMcn+Xpk0zjromFffpuzjletoA4tp9WrtT5UxJVryMS0pnp9EYFq/u6SdV61OrLFEWv+w1ifeE/FVWiyCFmMil9cSaWkD3snz7lV+GfMly6DN1+pdez4EHSxOri/L63WvytgbOW2NNZD1rMWYyP1Z46Os17rX9rVltFhA7X603q8aa2I+r2s9aKLKMJJzWvGNCREREUUGGyZEREQUGWyYEBERUWT0yjwm5eXlaedb+2/9xFtYc2hY+/q1eAgtpiTowGleZdL6L7XBobQBsZYuXZq2jPI8a33M1vMuywPoA1Zpxxw0xkOLt5DXQdBzBOi5GOQ6Yee3kHlMtHguSTtmP33k8pitcWphD+JnzXviJ8+JNfeRxhpLJPOYaDEk2nnXyutnUL+gzxTtWrPmq9Hi6LTyeNWJ9dqT5PyNGzemXR5gHhMiIiLqQ0wNk+XLl+Pyyy9HQUEBCgoKUFlZiRdeeCE5/9SpU6ipqcGIESMwePBgzJw5E62traEXmoiIiPomU8OkpKQEjz76KJqbm7Fnzx5cf/31mDZtGl5//XUAwIIFC7Bp0yasW7cOjY2NOHbsGGbMmNEjBSciIqI+yAU0bNgw9+tf/9q1tbW5AQMGuHXr1iXnvfHGGw6Aa2pq8r299vZ2B4Affvjhhx9++OmFn/b29kDtioxjTM6cOYO1a9fi5MmTqKysRHNzMzo7O1FVVZVcZty4cSgtLUVTU9MnbqejowOJRCLlQ0REROcnc8Pktddew+DBg5Gfn4+7774b69evx6WXXop4PI68vLxuEbyFhYWIx+OfuL36+nrEYrHkZ9SoUeaDICIior7B3DD5/Oc/j3379mHnzp245557UF1djYMHD2ZcgIULF6K9vT35OXr0aMbbIiIiot7NPFZOXl4ePvvZzwIAJkyYgN27d2Pp0qW49dZbcfr0abS1taW8NWltbUVRUdEnbi8/Px/5+fn2khMREVGfEziPSVdXFzo6OjBhwgQMGDAADQ0NyXktLS04cuQIKisrg+6GiIiIzgOmNyYLFy7E1KlTUVpaihMnTmDNmjXYtm0bXnrpJcRiMcydOxd1dXUYPnw4CgoKcO+996KyshLXXHNNT5WfiIiI+hBTw+T48eOYPXs23n33XcRiMVx++eV46aWX8NWvfhUAsHjxYuTm5mLmzJno6OjAlClT8PTTT5sK5KKVIZ+IiIgMgv4fj9xYOW+//TZ/mUNERNRLHT16FCUlJRmvH7mGSVdXF44dOwbnHEpLS3H06NFAgwGd7xKJBEaNGsV6DIB1GBzrMBysx+BYh8F9Uh0653DixAkUFxcjNzfzEFbzr3J6Wm5uLkpKSpKJ1j4el4eCYT0GxzoMjnUYDtZjcKzD4LzqMBaLBd4uRxcmIiKiyGDDhIiIiCIjsg2T/Px8PPjgg0y+FhDrMTjWYXCsw3CwHoNjHQbX03UYueBXIiIiOn9F9o0JERERnX/YMCEiIqLIYMOEiIiIIoMNEyIiIoqMyDZMnnrqKYwZMwYDBw5ERUUFdu3ale0iRVZ9fT2uuuoqDBkyBCNHjsT06dPR0tKSssypU6dQU1ODESNGYPDgwZg5cyZaW1uzVOLoe/TRR5GTk4Pa2trkd6xDf9555x1861vfwogRIzBo0CCMHz8ee/bsSc53zuGBBx7AxRdfjEGDBqGqqgqHDh3KYomj5cyZM1i0aBHKysowaNAgfOYzn8GPf/zjlPFHWIeptm/fjptvvhnFxcXIycnBhg0bUub7qa/3338fs2bNQkFBAYYOHYq5c+figw8+OIdHkX3p
6rGzsxP33Xcfxo8fjwsvvBDFxcWYPXs2jh07lrKNMOoxkg2T5557DnV1dXjwwQexd+9elJeXY8qUKTh+/Hi2ixZJjY2NqKmpwY4dO7BlyxZ0dnbixhtvxMmTJ5PLLFiwAJs2bcK6devQ2NiIY8eOYcaMGVksdXTt3r0bv/zlL3H55ZenfM861P3rX//CpEmTMGDAALzwwgs4ePAgfv7zn2PYsGHJZR5//HEsW7YMK1aswM6dO3HhhRdiypQpOHXqVBZLHh2PPfYYli9fjieffBJvvPEGHnvsMTz++ON44oknksuwDlOdPHkS5eXleOqppzzn+6mvWbNm4fXXX8eWLVuwefNmbN++HXfddde5OoRISFePH374Ifbu3YtFixZh7969eP7559HS0oJbbrklZblQ6tFF0NVXX+1qamqS02fOnHHFxcWuvr4+i6XqPY4fP+4AuMbGRuecc21tbW7AgAFu3bp1yWXeeOMNB8A1NTVlq5iRdOLECTd27Fi3ZcsWN3nyZDd//nznHOvQr/vuu89de+21nzi/q6vLFRUVuZ/97GfJ79ra2lx+fr77/e9/fy6KGHk33XSTu/POO1O+mzFjhps1a5ZzjnWoAeDWr1+fnPZTXwcPHnQA3O7du5PLvPDCCy4nJ8e9884756zsUSLr0cuuXbscAPfWW28558Krx8i9MTl9+jSam5tRVVWV/C43NxdVVVVoamrKYsl6j/b2dgDA8OHDAQDNzc3o7OxMqdNx48ahtLSUdSrU1NTgpptuSqkrgHXo1x//+EdMnDgR3/jGNzBy5EhceeWV+NWvfpWcf/jwYcTj8ZR6jMViqKioYD3+z5e+9CU0NDTgzTffBADs378fr7zyCqZOnQqAdWjlp76ampowdOhQTJw4MblMVVUVcnNzsXPnznNe5t6ivb0dOTk5GDp0KIDw6jFyg/i99957OHPmDAoLC1O+LywsxF/+8pcslar36OrqQm1tLSZNmoTLLrsMABCPx5GXl5e8eD5WWFiIeDyehVJG09q1a7F3717s3r272zzWoT9/+9vfsHz5ctTV1eEHP/gBdu/eje9+97vIy8tDdXV1sq687m/W43/df//9SCQSGDduHPr164czZ87gkUcewaxZswCAdWjkp77i8ThGjhyZMr9///4YPnw46/QTnDp1Cvfddx9uv/325EB+YdVj5BomFExNTQ0OHDiAV155JdtF6VWOHj2K+fPnY8uWLRg4cGC2i9NrdXV1YeLEifjpT38KALjyyitx4MABrFixAtXV1VkuXe/whz/8Ac8++yzWrFmDL3zhC9i3bx9qa2tRXFzMOqRI6OzsxDe/+U0457B8+fLQtx+5rpyLLroI/fr16/Zrh9bWVhQVFWWpVL3DvHnzsHnzZrz88ssoKSlJfl9UVITTp0+jra0tZXnW6f81Nzfj+PHj+OIXv4j+/fujf//+aGxsxLJly9C/f38UFhayDn24+OKLcemll6Z8d8kll+DIkSMAkKwr3t+f7Hvf+x7uv/9+3HbbbRg/fjy+/e1vY8GCBaivrwfAOrTyU19FRUXdflzx73//G++//z7rVPi4UfLWW29hy5YtybclQHj1GLmGSV5eHiZMmICGhobkd11dXWhoaEBlZWUWSxZdzjnMmzcP69evx9atW1FWVpYyf8KECRgwYEBKnba0tODIkSOs0/+54YYb8Nprr2Hfvn3Jz8SJEzFr1qzk36xD3aRJk7r9VP3NN9/E6NGjAQBlZWUoKipKqcdEIoGdO3eyHv/nww8/RG5u6qO5X79+6OrqAsA6tPJTX5WVlWhra0Nzc3Nyma1bt6KrqwsVFRXnvMxR9XGj5NChQ/jTn/6EESNGpMwPrR4zCNbtcWvXrnX5+flu1apV7uDBg+6uu+5yQ4cOdfF4PNtFi6R77rnHxWIxt23bNvfuu+8mPx9++GFymbvvvtuVlpa6rVu3uj179rjKykpXWVmZxVJH39m/ynGOdejHrl27XP/+/d0j
jzziDh065J599ll3wQUXuN/97nfJZR599FE3dOhQt3HjRvfnP//ZTZs2zZWVlbmPPvooiyWPjurqavepT33Kbd682R0+fNg9//zz7qKLLnLf//73k8uwDlOdOHHCvfrqq+7VV191ANwvfvEL9+qrryZ/LeKnvr72ta+5K6+80u3cudO98sorbuzYse7222/P1iFlRbp6PH36tLvllltcSUmJ27dvX8r/mo6OjuQ2wqjHSDZMnHPuiSeecKWlpS4vL89dffXVbseOHdkuUmQB8PysXLkyucxHH33kvvOd77hhw4a5Cy64wH3961937777bvYK3QvIhgnr0J9Nmza5yy67zOXn57tx48a5Z555JmV+V1eXW7RokSssLHT5+fnuhhtucC0tLVkqbfQkEgk3f/58V1pa6gYOHOg+/elPux/+8IcpD3/WYaqXX37Z8xlYXV3tnPNXX//85z/d7bff7gYPHuwKCgrcHXfc4U6cOJGFo8medPV4+PDhT/xf8/LLLye3EUY95jh3VjpBIiIioiyKXIwJERERnb/YMCEiIqLIYMOEiIiIIoMNEyIiIooMNkyIiIgoMtgwISIioshgw4SIiIgigw0TIiIiigw2TIiIiCgy2DAhIiKiyGDDhIiIiCKDDRMiIiKKjP8AhoB1uxYGu3IAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116,\n",
      "         0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "        [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
      "         0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
      "       
Download .txt
gitextract_jcwmbccq/

├── .gitignore
├── CITATION.cff
├── d3pm_runner.py
├── d3pm_runner_cifar10.py
├── dit.py
├── lm.py
├── lm_deepspeed.py
├── readme.md
├── run_multigpu.sh
└── test.ipynb
Download .txt
SYMBOL INDEX (62 symbols across 4 files)

FILE: d3pm_runner.py
  class DummyX0Model (line 39) | class DummyX0Model(nn.Module):
    method __init__ (line 41) | def __init__(self, n_channel: int, N: int = 16) -> None:
    method forward (line 72) | def forward(self, x, t, cond) -> torch.Tensor:
  class D3PM (line 134) | class D3PM(nn.Module):
    method __init__ (line 135) | def __init__(
    method _at (line 193) | def _at(self, a, t, x):
    method q_posterior_logits (line 200) | def q_posterior_logits(self, x_0, x_t, t):
    method vb (line 237) | def vb(self, dist1, dist2):
    method q_sample (line 249) | def q_sample(self, x_0, t, noise):
    method model_predict (line 256) | def model_predict(self, x_0, t, cond):
    method forward (line 266) | def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch...
    method p_sample (line 299) | def p_sample(self, x, t, cond, noise):
    method sample (line 314) | def sample(self, x, cond=None):
    method sample_with_image_sequence (line 323) | def sample_with_image_sequence(self, x, cond=None, stride=10):

FILE: dit.py
  function modulate (line 10) | def modulate(x, shift, scale):
  class TimestepEmbedder (line 14) | class TimestepEmbedder(nn.Module):
    method __init__ (line 15) | def __init__(self, hidden_size, frequency_embedding_size=256):
    method timestep_embedding (line 25) | def timestep_embedding(t, dim, max_period=10000):
    method forward (line 38) | def forward(self, t):
  class LabelEmbedder (line 46) | class LabelEmbedder(nn.Module):
    method __init__ (line 47) | def __init__(self, num_classes, hidden_size, dropout_prob):
    method token_drop (line 56) | def token_drop(self, labels, force_drop_ids=None):
    method forward (line 66) | def forward(self, labels, train, force_drop_ids=None):
  class Attention (line 74) | class Attention(nn.Module):
    method __init__ (line 75) | def __init__(self, dim, n_heads):
    method reshape_for_broadcast (line 91) | def reshape_for_broadcast(freqs_cis, x):
    method apply_rotary_emb (line 99) | def apply_rotary_emb(xq, xk, freqs_cis):
    method forward (line 107) | def forward(self, x, freqs_cis):
  class FeedForward (line 134) | class FeedForward(nn.Module):
    method __init__ (line 135) | def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=No...
    method _forward_silu_gating (line 146) | def _forward_silu_gating(self, x1, x3):
    method forward (line 149) | def forward(self, x):
  class TransformerBlock (line 153) | class TransformerBlock(nn.Module):
    method __init__ (line 154) | def __init__(
    method forward (line 182) | def forward(self, x, freqs_cis, adaln_input=None):
  class FinalLayer (line 201) | class FinalLayer(nn.Module):
    method __init__ (line 202) | def __init__(self, hidden_size, patch_size, out_channels):
    method forward (line 216) | def forward(self, x, c):
  class DDiT_Llama (line 223) | class DDiT_Llama(nn.Module):
    method __init__ (line 225) | def __init__(
    method forward (line 262) | def forward(self, x, t, cond=None):
  class DiT_Llama (line 281) | class DiT_Llama(nn.Module):
    method __init__ (line 282) | def __init__(
    method unpatchify (line 338) | def unpatchify(self, x):
    method patchify (line 347) | def patchify(self, x):
    method forward (line 360) | def forward(self, x, t, y):
    method forward_with_cfg (line 394) | def forward_with_cfg(self, x, t, y, cfg_scale):
    method precompute_freqs_cis (line 405) | def precompute_freqs_cis(dim, end, theta=10000.0):
  function DiT_Llama_600M_patch2 (line 413) | def DiT_Llama_600M_patch2(**kwargs):
  function DiT_Llama_3B_patch2 (line 417) | def DiT_Llama_3B_patch2(**kwargs):

FILE: lm.py
  class WikiTextDataset (line 14) | class WikiTextDataset(Dataset):
    method __init__ (line 15) | def __init__(self, tokenizer=None, type_path="train", max_length=512, ...
    method __len__ (line 25) | def __len__(self):
    method __getitem__ (line 30) | def __getitem__(self, idx):

FILE: lm_deepspeed.py
  class WikiTextDataset (line 23) | class WikiTextDataset(Dataset):
    method __init__ (line 24) | def __init__(
    method __len__ (line 38) | def __len__(self):
    method __getitem__ (line 41) | def __getitem__(self, idx):
  function _z3_params_to_fetch (line 65) | def _z3_params_to_fetch(param_list):
  function save_zero_three_model (line 73) | def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
  function set_seed (line 100) | def set_seed(seed=42):
  function main (line 120) | def main(
Condensed preview — 10 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (227K chars).
[
  {
    "path": ".gitignore",
    "chars": 65,
    "preview": "presetup.sh\nsetup.sh\ncu122py310\ndata\ntest\nwandb\nrun.sh\n*.pyc\nckpt"
  },
  {
    "path": "CITATION.cff",
    "chars": 420,
    "preview": "cff-version: 1.2.0\nmessage: \"Citations would be appreciated if you end up using this tool! I currently go by Simo Ryu\"\na"
  },
  {
    "path": "d3pm_runner.py",
    "chars": 15018,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom PIL import Image\nfrom torch.utils.data import DataLoader\nfrom"
  },
  {
    "path": "d3pm_runner_cifar10.py",
    "chars": 3983,
    "preview": "import numpy as np\nimport torch\nfrom PIL import Image\nfrom torch.utils.data import DataLoader\nfrom torchvision import tr"
  },
  {
    "path": "dit.py",
    "chars": 14166,
    "preview": "# Code heavilty based on https://github.com/Alpha-VLLM/LLaMA2-Accessory\n\nimport math\n\nimport torch\nimport torch.nn as nn"
  },
  {
    "path": "lm.py",
    "chars": 5494,
    "preview": "import math\n\nimport torch\nfrom datasets import load_dataset\nfrom torch.utils.data import DataLoader, Dataset\nfrom tqdm i"
  },
  {
    "path": "lm_deepspeed.py",
    "chars": 10476,
    "preview": "import math\nimport os\nimport random\n\nimport click\nimport deepspeed\nimport numpy as np\nimport torch\nfrom datasets import "
  },
  {
    "path": "readme.md",
    "chars": 2601,
    "preview": "<p align=\"center\">\n  <img src=\"contents/output.gif\" alt=\"large\" width=\"400\">\n  <img src=\"contents/cifar_best.gif\" alt=\"l"
  },
  {
    "path": "run_multigpu.sh",
    "chars": 112,
    "preview": "export WORLD_SIZE=$(nvidia-smi -L | wc -l)\ndeepspeed --num_gpus $WORLD_SIZE lm_deepspeed.py --learning_rate 1e-4"
  },
  {
    "path": "test.ipynb",
    "chars": 168127,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  }
]

About this extraction

This page contains the full source code of the cloneofsimo/d3pm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 10 files (215.3 KB), approximately 112.3k tokens, and a symbol index with 62 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!