Repository: cloneofsimo/d3pm
Branch: main
Commit: 3ceb63725ee2
Files: 10
Total size: 215.3 KB
Directory structure:
gitextract_jcwmbccq/
├── .gitignore
├── CITATION.cff
├── d3pm_runner.py
├── d3pm_runner_cifar10.py
├── dit.py
├── lm.py
├── lm_deepspeed.py
├── readme.md
├── run_multigpu.sh
└── test.ipynb
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
presetup.sh
setup.sh
cu122py310
data
test
wandb
run.sh
*.pyc
ckpt
================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "Citations would be appreciated if you end up using this tool! I currently go by Simo Ryu"
authors:
- family-names: "Ryu"
given-names: "Simo"
orcid: "https://orcid.org/0009-0008-0017-2677"
title: "Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch"
version: 0.0.1
date-released: 2024-04
url: "https://github.com/cloneofsimo/d3pm"
================================================
FILE: d3pm_runner.py
================================================
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm import tqdm
def blk(ic, oc):
    """Build three (5x5 conv -> GroupNorm -> LeakyReLU) stages at constant resolution.

    Args:
        ic: input channel count.
        oc: output channel count; GroupNorm uses ``oc // 8`` groups, so ``oc``
            should be a multiple of 8.

    Returns:
        An ``nn.Sequential`` mapping (B, ic, H, W) -> (B, oc, H, W)
        (padding=2 keeps the spatial size for a 5x5 kernel).
    """
    return nn.Sequential(
        nn.Conv2d(ic, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
    )
def blku(ic, oc):
    """Build three (5x5 conv -> GroupNorm -> LeakyReLU) stages followed by 2x upsampling.

    The upsampling is a stride-2 ``ConvTranspose2d`` (kernel 2) plus
    GroupNorm/LeakyReLU, so the block maps (B, ic, H, W) -> (B, oc, 2H, 2W).

    Args:
        ic: input channel count.
        oc: output channel count; GroupNorm uses ``oc // 8`` groups, so ``oc``
            should be a multiple of 8.
    """
    return nn.Sequential(
        nn.Conv2d(ic, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.Conv2d(oc, oc, 5, padding=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
        nn.ConvTranspose2d(oc, oc, 2, stride=2),
        nn.GroupNorm(oc // 8, oc),
        nn.LeakyReLU(),
    )
class DummyX0Model(nn.Module):
    """Conv U-Net with transformer bottlenecks that predicts x_0 logits.

    Args:
        n_channel: number of image channels.
        N: number of discrete states per pixel; the model emits N logits per
            pixel and channel.

    ``forward(x, t, cond)`` takes integer images ``x`` of shape (B, C, H, W)
    with values in ``[0, N)``, timesteps ``t`` of shape (B,), and class labels
    ``cond`` of shape (B,) with values in ``[0, 10)``. It returns logits of
    shape (B, C, H, W, N). H and W must be divisible by 16 (four 2x
    poolings, mirrored by four 2x upsamplings).
    """

    def __init__(self, n_channel: int, N: int = 16) -> None:
        super().__init__()
        self.down1 = blk(n_channel, 16)
        self.down2 = blk(16, 32)
        self.down3 = blk(32, 64)
        self.down4 = blk(64, 512)
        self.down5 = blk(512, 512)
        self.up1 = blku(512, 512)
        self.up2 = blku(512 + 512, 64)
        self.up3 = blku(64, 32)
        self.up4 = blku(32, 16)
        self.convlast = blk(16, 16)
        self.final = nn.Conv2d(16, N * n_channel, 1, bias=False)
        # The token tensors passed to these layers are (B, L, C), so
        # batch_first=True is required. With the default (False), attention
        # would run across the batch instead of across spatial positions.
        self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.tr2 = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)
        self.cond_embedding_1 = nn.Embedding(10, 16)
        self.cond_embedding_2 = nn.Embedding(10, 32)
        self.cond_embedding_3 = nn.Embedding(10, 64)
        self.cond_embedding_4 = nn.Embedding(10, 512)
        self.cond_embedding_5 = nn.Embedding(10, 512)
        self.cond_embedding_6 = nn.Embedding(10, 64)
        self.temb_1 = nn.Linear(32, 16)
        self.temb_2 = nn.Linear(32, 32)
        self.temb_3 = nn.Linear(32, 64)
        self.temb_4 = nn.Linear(32, 512)
        self.N = N

    @staticmethod
    def _attend(layer, y):
        """Apply a batch-first transformer layer over the H*W positions of a (B, C, H, W) map."""
        tokens = y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)  # (B, H*W, C)
        return layer(tokens).transpose(1, 2).reshape(y.shape)

    def forward(self, x, t, cond) -> torch.Tensor:
        # Map integer states {0, ..., N-1} into [-1, 1).
        x = (2 * x.float() / self.N) - 1.0
        # 32 sinusoidal time features (16 sin + 16 cos); t is scaled by 1000.
        t = t.float().reshape(-1, 1) / 1000
        t_features = [torch.sin(t * 3.1415 * 2**i) for i in range(16)] + [
            torch.cos(t * 3.1415 * 2**i) for i in range(16)
        ]
        tx = torch.cat(t_features, dim=1).to(x.device)
        t_emb_1 = self.temb_1(tx).unsqueeze(-1).unsqueeze(-1)
        t_emb_2 = self.temb_2(tx).unsqueeze(-1).unsqueeze(-1)
        t_emb_3 = self.temb_3(tx).unsqueeze(-1).unsqueeze(-1)
        t_emb_4 = self.temb_4(tx).unsqueeze(-1).unsqueeze(-1)
        cond_emb_1 = self.cond_embedding_1(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_2 = self.cond_embedding_2(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_3 = self.cond_embedding_3(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_4 = self.cond_embedding_4(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_5 = self.cond_embedding_5(cond).unsqueeze(-1).unsqueeze(-1)
        cond_emb_6 = self.cond_embedding_6(cond).unsqueeze(-1).unsqueeze(-1)
        x1 = self.down1(x) + t_emb_1 + cond_emb_1
        x2 = self.down2(nn.functional.avg_pool2d(x1, 2)) + t_emb_2 + cond_emb_2
        x3 = self.down3(nn.functional.avg_pool2d(x2, 2)) + t_emb_3 + cond_emb_3
        x4 = self.down4(nn.functional.avg_pool2d(x3, 2)) + t_emb_4 + cond_emb_4
        x5 = self.down5(nn.functional.avg_pool2d(x4, 2))
        x5 = self._attend(self.tr1, x5)
        y = self.up1(x5) + cond_emb_5
        y = self._attend(self.tr2, y)
        y = self.up2(torch.cat([x4, y], dim=1)) + cond_emb_6
        y = self._attend(self.tr3, y)
        y = self.up3(y)
        y = self.up4(y)
        y = self.convlast(y)
        y = self.final(y)
        # (B, C*N, H, W) -> (B, C, N, H, W) -> (B, C, H, W, N).
        # permute, not transpose(2, -1): the latter would also swap H and W.
        y = (
            y.reshape(y.shape[0], -1, self.N, *x.shape[2:])
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )
        return y
class D3PM(nn.Module):
    """Discrete denoising diffusion (D3PM) with a uniform forward process.

    Args:
        x0_model: module called as ``x0_model(x_t, t, cond)``. It must return
            logits of shape ``x_t.shape + (num_classes,)`` predicting x_0.
        n_T: number of diffusion steps.
        num_classes: number of discrete states per element.
        forward_type: only ``"uniform"`` is implemented.
        hybrid_loss_coeff: weight of the variational-bound term added to the
            cross-entropy loss.
    """

    def __init__(
        self,
        x0_model: nn.Module,
        n_T: int,
        num_classes: int = 10,
        forward_type="uniform",
        hybrid_loss_coeff=0.001,
    ) -> None:
        super().__init__()
        if forward_type != "uniform":
            raise NotImplementedError(f"forward_type={forward_type!r} is not supported")
        self.x0_model = x0_model
        self.n_T = n_T
        self.hybrid_loss_coeff = hybrid_loss_coeff
        # Cosine alpha_bar schedule; beta_t = 1 - abar_t / abar_{t-1}, capped at 0.999.
        steps = torch.arange(n_T + 1, dtype=torch.float64) / n_T
        alpha_bar = torch.cos((steps + 0.008) / 1.008 * torch.pi / 2)
        self.beta_t = torch.minimum(
            1 - alpha_bar[1:] / alpha_bar[:-1], torch.ones_like(alpha_bar[1:]) * 0.999
        )
        self.eps = 1e-6
        # NOTE: misspelled attribute name kept for compatibility with existing code.
        self.num_classses = num_classes

        # One-step matrices Q_t. The process stays put with probability
        # 1 - (K-1) beta / K and moves to each other state with probability beta / K.
        q_onestep_mats = []
        for beta in self.beta_t:
            mat = torch.ones(num_classes, num_classes) * beta / num_classes
            mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)
            q_onestep_mats.append(mat)
        q_one_step_mats = torch.stack(q_onestep_mats, dim=0)
        # Q_t^T lets q_posterior_logits read q(x_t | x_{t-1}) as a function of x_{t-1}.
        q_one_step_transposed = q_one_step_mats.transpose(1, 2)

        # Cumulative products: q_mats[t - 1] = Q_1 @ Q_2 @ ... @ Q_t.
        q_mat_t = q_onestep_mats[0]
        q_mats = [q_mat_t]
        for idx in range(1, self.n_T):
            q_mat_t = q_mat_t @ q_onestep_mats[idx]
            q_mats.append(q_mat_t)
        q_mats = torch.stack(q_mats, dim=0)
        self.logit_type = "logit"

        self.register_buffer("q_one_step_transposed", q_one_step_transposed)
        self.register_buffer("q_mats", q_mats)
        assert self.q_mats.shape == (
            self.n_T,
            num_classes,
            num_classes,
        ), self.q_mats.shape

    def _at(self, a, t, x):
        """Gather ``a[t - 1, x, :]`` elementwise.

        ``t`` is 1-d of length B (1-based step index). ``x`` holds integer
        states with leading dimension B. The result has shape
        ``x.shape + (a.shape[-1],)``.
        """
        bs = t.shape[0]
        t = t.reshape((bs, *[1] * (x.dim() - 1)))
        return a[t - 1, x, :]

    def q_posterior_logits(self, x_0, x_t, t):
        """Logits of q(x_{t-1} | x_t, x_0). Where t == 1, return the x_0 logits instead.

        ``x_0`` may be integer states (converted to log of a near-one-hot) or
        logits of shape ``x_t.shape + (num_classes,)``. This implements
        equation (3) of the D3PM paper without the normalizer x_0 Qbar_t x_t^T.
        t is never 0.
        """
        if x_0.dtype == torch.int64 or x_0.dtype == torch.int32:
            x_0_logits = torch.log(
                torch.nn.functional.one_hot(x_0, self.num_classses) + self.eps
            )
        else:
            x_0_logits = x_0.clone()
        # An f-string message: the previous `print(...)` evaluated to None,
        # which left the AssertionError without a message.
        assert x_0_logits.shape == x_t.shape + (self.num_classses,), (
            f"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}"
        )
        # fact1: "guess of x_{t-1}" from x_t, i.e. q(x_t | x_{t-1}) over x_{t-1}.
        fact1 = self._at(self.q_one_step_transposed, t, x_t)
        # fact2: "guess of x_{t-1}" from x_0, i.e. softmax(x_0 logits) @ Qbar_{t-1}.
        softmaxed = torch.softmax(x_0_logits, dim=-1)  # bs, ..., num_classes
        # For t == 1 this indexes q_mats[-1]; that value is discarded by the where below.
        qmats2 = self.q_mats[t - 2].to(dtype=softmaxed.dtype)  # bs, K, K
        fact2 = torch.einsum("b...c,bcd->b...d", softmaxed, qmats2)
        out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)
        t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim())))
        return torch.where(t_broadcast == 1, x_0_logits, out)

    def vb(self, dist1, dist2):
        """Mean over positions of KL(softmax(dist1) || softmax(dist2)); inputs are logits."""
        dist1 = dist1.flatten(start_dim=0, end_dim=-2)
        dist2 = dist2.flatten(start_dim=0, end_dim=-2)
        out = torch.softmax(dist1 + self.eps, dim=-1) * (
            torch.log_softmax(dist1 + self.eps, dim=-1)
            - torch.log_softmax(dist2 + self.eps, dim=-1)
        )
        return out.sum(dim=-1).mean()

    def q_sample(self, x_0, t, noise):
        """Sample x_t ~ q(x_t | x_0) with the Gumbel-max trick, using uniform ``noise``."""
        logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)
        noise = torch.clip(noise, self.eps, 1.0)
        gumbel_noise = -torch.log(-torch.log(noise))
        return torch.argmax(logits + gumbel_noise, dim=-1)

    def model_predict(self, x_0, t, cond):
        """Return the x_0 logits predicted by ``x0_model``.

        This is a hook: other parameterisations (e.g. the logistic outputs of
        appendix A.8) could be converted to x_0 logits here.
        """
        predicted_x0_logits = self.x0_model(x_0, t, cond)
        return predicted_x0_logits

    def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> tuple:
        """Compute the training loss on a batch of clean integer inputs.

        Args:
            x: integer states of shape (bs, ...) with values in
                ``[0, num_classes)``.
            cond: optional conditioning forwarded to ``x0_model``.

        Returns:
            ``(loss, info)``. ``loss`` is ``hybrid_loss_coeff * vb + ce``, and
            ``info`` holds float ``"vb_loss"`` and ``"ce_loss"``. t is drawn
            uniformly from ``[1, n_T - 1]``.
        """
        t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)
        x_t = self.q_sample(
            x, t, torch.rand((*x.shape, self.num_classses), device=x.device)
        )
        assert x_t.shape == x.shape, f"x_t.shape: {x_t.shape}, x.shape: {x.shape}"
        predicted_x0_logits = self.model_predict(x_t, t, cond)
        true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)
        vb_loss = self.vb(true_q_posterior_logits, pred_q_posterior_logits)
        predicted_x0_logits = predicted_x0_logits.flatten(start_dim=0, end_dim=-2)
        x = x.flatten(start_dim=0, end_dim=-1)
        ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x)
        return self.hybrid_loss_coeff * vb_loss + ce_loss, {
            "vb_loss": vb_loss.detach().item(),
            "ce_loss": ce_loss.detach().item(),
        }

    def p_sample(self, x, t, cond, noise):
        """One reverse step: sample x_{t-1} from the model's posterior. At t == 1, take the argmax."""
        predicted_x0_logits = self.model_predict(x, t, cond)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)
        noise = torch.clip(noise, self.eps, 1.0)
        not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim())))
        gumbel_noise = -torch.log(-torch.log(noise))
        sample = torch.argmax(
            pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1
        )
        return sample

    def _reverse_steps(self, x, cond):
        """Run the reverse chain for t = n_T - 1, ..., 1, yielding x after each step."""
        for step in range(self.n_T - 1, 0, -1):
            t = torch.full((x.shape[0],), step, dtype=torch.long, device=x.device)
            noise = torch.rand((*x.shape, self.num_classses), device=x.device)
            x = self.p_sample(x, t, cond, noise)
            yield x

    def sample(self, x, cond=None):
        """Denoise ``x`` (integer states) through the full reverse chain; return the final sample."""
        for x in self._reverse_steps(x, cond):
            pass  # only the last state is needed
        return x

    def sample_with_image_sequence(self, x, cond=None, stride=10):
        """Like ``sample`` but return the state after every ``stride`` steps.

        The final state is appended as well when the step count is not a
        multiple of ``stride``.
        """
        images = []
        steps = 0
        for steps, x in enumerate(self._reverse_steps(x, cond), start=1):
            if steps % stride == 0:
                images.append(x)
        if steps % stride != 0:
            images.append(x)
        return images
if __name__ == "__main__":
    import os

    N = 2  # number of classes for discretized state per pixel
    d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes=N, hybrid_loss_coeff=0.0).cuda()
    print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}")
    dataset = MNIST(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Pad(2),
            ]
        ),
    )
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32)
    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=1e-3)
    # Sample GIFs/PNGs are written here, and PIL's save() does not create
    # missing directories.
    os.makedirs("contents", exist_ok=True)
    d3pm.train()
    n_epoch = 400
    device = "cuda"
    global_step = 0
    for i in range(n_epoch):
        pbar = tqdm(dataloader)
        loss_ema = None
        for x, cond in pbar:
            optim.zero_grad()
            x = x.to(device)
            cond = cond.to(device)
            # discretize x to N bins
            x = (x * (N - 1)).round().long().clamp(0, N - 1)
            loss, info = d3pm(x, cond)
            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.1)
            with torch.no_grad():
                param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            global_step += 1
            if global_step % 300 == 1:
                d3pm.eval()
                with torch.no_grad():
                    cond = torch.arange(0, 4).cuda() % 10
                    init_noise = torch.randint(0, N, (4, 1, 32, 32)).cuda()
                    images = d3pm.sample_with_image_sequence(
                        init_noise, cond, stride=40
                    )
                    # image sequences to gif
                    gif = []
                    for image in images:
                        x_as_image = make_grid(image.float() / (N - 1), nrow=2)
                        img = x_as_image.permute(1, 2, 0).cpu().numpy()
                        img = (img * 255).astype(np.uint8)
                        gif.append(Image.fromarray(img))
                    gif[0].save(
                        f"contents/sample_{global_step}.gif",
                        save_all=True,
                        append_images=gif[1:],
                        duration=100,
                        loop=0,
                    )
                    last_img = gif[-1]
                    last_img.save(f"contents/sample_{global_step}_last.png")
                d3pm.train()
================================================
FILE: d3pm_runner_cifar10.py
================================================
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid
from tqdm import tqdm
import wandb
from d3pm_runner import D3PM
from dit import DiT_Llama
if __name__ == "__main__":
    import os

    wandb.init(project="d3pm_cifar10")
    N = 8  # number of classes for discretized state per pixel
    d3pm = D3PM(
        DiT_Llama(3, N, dim=1024), 1000, num_classes=N, hybrid_loss_coeff=0.0
    ).cuda()
    print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}")
    dataset = CIFAR10(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        ),
    )
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=16)
    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=2e-5)
    # Sample GIFs/PNGs are written here, and PIL's save() does not create
    # missing directories.
    os.makedirs("contents", exist_ok=True)
    d3pm.train()
    n_epoch = 4000
    device = "cuda"
    global_step = 0
    for i in range(n_epoch):
        pbar = tqdm(dataloader)
        loss_ema = None
        for x, cond in pbar:
            optim.zero_grad()
            x = x.to(device)
            cond = cond.to(device)
            # discretize x to N bins
            x_cat = (x * (N - 1)).round().long().clamp(0, N - 1)
            loss, info = d3pm(x_cat, cond)
            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 5.0)
            with torch.no_grad():
                param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            if global_step % 10 == 0:
                # Log plain floats instead of (graph-attached) tensors.
                wandb.log(
                    {
                        "train_loss": loss.item(),
                        "train_grad_norm": norm.item(),
                        "train_param_norm": param_norm.item(),
                    }
                )
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            global_step += 1
            if global_step % 600 == 1:
                d3pm.eval()
                with torch.no_grad():
                    cond = torch.arange(0, 16).cuda() % 10
                    init_noise = torch.randint(0, N, (16, 3, 32, 32)).cuda()
                    images = d3pm.sample_with_image_sequence(
                        init_noise, cond, stride=40
                    )
                    # Image sequences to gif: the current real (discretized)
                    # batch is shown above the samples. The real batch is the
                    # same for every frame, so compute it once.
                    x_from_dataloader = x_cat[:16].cpu() / (N - 1)
                    gif = []
                    for image in images:
                        this_image = image.float().cpu() / (N - 1)
                        all_images = torch.cat([x_from_dataloader, this_image], dim=0)
                        x_as_image = make_grid(all_images, nrow=4)
                        img = x_as_image.permute(1, 2, 0).cpu().numpy()
                        img = (img * 255).astype(np.uint8)
                        gif.append(Image.fromarray(img))
                    gif[0].save(
                        f"contents/sample_{global_step}.gif",
                        save_all=True,
                        append_images=gif[1:],
                        duration=100,
                        loop=0,
                    )
                    last_img = gif[-1]
                    last_img.save(f"contents/sample_{global_step}_last.png")
                    # log images
                    wandb.log(
                        {
                            "sample": wandb.Image(last_img),
                        }
                    )
                d3pm.train()
================================================
FILE: dit.py
================================================
# Code heavily based on https://github.com/Alpha-VLLM/LLaMA2-Accessory
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation, ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at position 1, so
    they broadcast over dimension 1 of ``x``. In the callers in this file,
    that is the token axis of (batch, seq, dim) hidden states.
    """
    gain = 1 + scale.unsqueeze(1)
    bias = shift.unsqueeze(1)
    return x * gain + bias
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps as ``hidden_size`` vectors.

    Timesteps are first turned into sinusoidal features of width
    ``frequency_embedding_size``, then passed through Linear -> SiLU -> Linear.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return the [cos | sin] sinusoidal features of the 1-d tensor ``t``, shape (len(t), dim).

        Frequencies decay geometrically from 1 to roughly ``1 / max_period``.
        When ``dim`` is odd, a zero column is appended.
        """
        half = dim // 2
        scaled = -math.log(max_period) * torch.arange(start=0, end=half)
        freqs = torch.exp(scaled / half).to(t.device)
        angles = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2 == 1:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        param_dtype = next(self.parameters()).dtype
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features.to(dtype=param_dtype))
class LabelEmbedder(nn.Module):
    """Embed class labels, with optional label dropout for classifier-free guidance.

    When ``dropout_prob > 0``, one extra embedding row (index ``num_classes``)
    is allocated. It serves as the "null" label for dropped samples.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = nn.Embedding(
            num_classes + use_cfg_embedding, hidden_size
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace some labels with the null label ``num_classes``.

        Args:
            labels: 1-d tensor of class indices.
            force_drop_ids: optional tensor. Positions equal to 1 are dropped.
                If omitted, each label is dropped independently with
                probability ``dropout_prob``.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
            # Move to the labels' device only. An unconditional `.cuda()`
            # here would fail on CPU-only machines.
            drop_ids = drop_ids.to(labels.device)
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
class Attention(nn.Module):
    """Multi-head self-attention with LayerNorm on Q/K and rotary position embeddings."""

    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.n_rep = 1
        self.head_dim = dim // n_heads
        inner = n_heads * self.head_dim
        self.wq = nn.Linear(dim, inner, bias=False)
        self.wk = nn.Linear(dim, inner, bias=False)
        self.wv = nn.Linear(dim, inner, bias=False)
        self.wo = nn.Linear(inner, dim, bias=False)
        self.q_norm = nn.LayerNorm(inner)
        self.k_norm = nn.LayerNorm(inner)

    @staticmethod
    def reshape_for_broadcast(freqs_cis, x):
        """View ``freqs_cis`` (shape ``(x.shape[1], x.shape[-1])``) so it broadcasts against ``x``."""
        ndim = x.ndim
        assert ndim > 1
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        last = ndim - 1
        shape = [size if axis in (1, last) else 1 for axis, size in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(xq, xk, freqs_cis):
        """Rotate consecutive feature pairs of ``xq``/``xk`` by the complex ``freqs_cis``.

        Both outputs are float32, with the same shape as the inputs.
        """

        def as_complex_pairs(tensor):
            return torch.view_as_complex(tensor.float().reshape(*tensor.shape[:-1], -1, 2))

        q_complex = as_complex_pairs(xq)
        k_complex = as_complex_pairs(xk)
        rotation = Attention.reshape_for_broadcast(freqs_cis, q_complex)
        q_rot = torch.view_as_real(q_complex * rotation).flatten(3)
        k_rot = torch.view_as_real(k_complex * rotation).flatten(3)
        return q_rot, k_rot

    def forward(self, x, freqs_cis):
        bsz, seqlen, _ = x.shape
        split_heads = (bsz, seqlen, self.n_heads, self.head_dim)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        dtype = q.dtype
        q = self.q_norm(q).view(split_heads)
        k = self.k_norm(k).view(split_heads)
        v = v.view(split_heads)
        q, k = self.apply_rotary_emb(q, k, freqs_cis=freqs_cis)
        # RoPE runs in float32; cast back to the projection dtype.
        q = q.to(dtype)
        k = k.to(dtype)
        # SDPA wants (batch, heads, seq, head_dim).
        attended = F.scaled_dot_product_attention(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 1, 3),
            v.permute(0, 2, 1, 3),
            dropout_p=0.0,
            is_causal=False,
        )
        merged = attended.permute(0, 2, 1, 3).flatten(-2)
        return self.wo(merged)
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    The inner width is ``2/3 * hidden_dim``, optionally scaled by
    ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    """

    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier:
            inner = int(ffn_dim_multiplier * inner)
        # Ceiling division, then scale back: round up to a multiple of `multiple_of`.
        inner = -(-inner // multiple_of) * multiple_of
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        gated = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(gated)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block (attention + SwiGLU MLP) with optional adaLN conditioning.

    When ``adaln_input`` is given, it is mapped to six vectors: shift, scale
    and gate for the attention branch, and the same three for the MLP branch.
    Its last dimension must be ``min(dim, 1024)``.
    """

    def __init__(
        self,
        layer_id,
        dim,
        n_heads,
        multiple_of,
        ffn_dim_multiplier,
        norm_eps,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = nn.LayerNorm(dim, eps=norm_eps)
        self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps)
        cond_width = min(dim, 1024)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_width, 6 * dim, bias=True),
        )

    def forward(self, x, freqs_cis, adaln_input=None):
        if adaln_input is None:
            x = x + self.attention(self.attention_norm(x), freqs_cis)
            return x + self.feed_forward(self.ffn_norm(x))

        modulation = self.adaLN_modulation(adaln_input).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation
        attn_in = modulate(self.attention_norm(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attention(attn_in, freqs_cis)
        mlp_in = modulate(self.ffn_norm(x), shift_mlp, scale_mlp)
        return x + gate_mlp.unsqueeze(1) * self.feed_forward(mlp_in)
class FinalLayer(nn.Module):
    """adaLN-modulated output head that projects hidden states to ``patch_size**2 * out_channels`` values per token.

    The projection is zero-initialized, so the head outputs zeros at the
    start of training.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        out_features = patch_size * patch_size * out_channels
        self.linear = nn.Linear(hidden_size, out_features, bias=True)
        cond_width = min(hidden_size, 1024)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_width, 2 * hidden_size, bias=True),
        )
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        conditioned = modulate(self.norm_final(x), shift, scale)
        return self.linear(conditioned)
class DDiT_Llama(nn.Module):
    """Timestep-conditioned transformer over discrete token sequences.

    Input ``x`` holds integer tokens with values in ``[0, N)``. It is embedded
    per token, and ``x.size(1)`` is used as the sequence length, which must
    not exceed the 4096 precomputed rotary positions. The output is N logits
    per token, with the one-hot input added as a residual. When
    ``learn_gating`` is set, that residual is scaled by ``|1 + gate|``, where
    the gate is predicted by the head. ``cond`` is accepted but unused.
    """

    def __init__(
        self,
        N=256,
        dim=512,
        n_layers=5,
        n_heads=16,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_gating=False,
    ):
        super().__init__()
        self.N = N
        self.learn_gating = learn_gating
        # With gating, the head emits N logits plus N gate values.
        self.out_channel = N * 2 if self.learn_gating else N
        self.embedder = nn.Embedding(N, dim)
        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        blocks = [
            TransformerBlock(
                layer_id,
                dim,
                n_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
            )
            for layer_id in range(n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.final_layer = FinalLayer(dim, 1, self.out_channel)
        self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 4096)

    def forward(self, x, t, cond=None):
        # freqs_cis is a plain attribute (not a buffer), so move it lazily.
        self.freqs_cis = self.freqs_cis.to(x.device)
        param_dtype = next(self.parameters()).dtype
        x_onehot = F.one_hot(x, self.N).to(x.device, dtype=param_dtype)
        hidden = self.embedder(x)
        adaln_input = self.t_embedder(t)
        rope = self.freqs_cis[: hidden.size(1)]
        for layer in self.layers:
            hidden = layer(hidden, rope, adaln_input=adaln_input)
        logits = self.final_layer(hidden, adaln_input)
        if not self.learn_gating:
            return logits + x_onehot
        logits, gate = logits.chunk(2, dim=-1)
        return logits + x_onehot * (1 + gate).abs()
class DiT_Llama(nn.Module):
    """Class- and timestep-conditioned diffusion transformer for discrete images.

    The input ``x`` is an integer image of shape (B, in_channels, H, W) with
    values in ``[0, N)``. The forward pass:

    1. rescales ``x`` to ``[-1, 1]`` and applies a two-layer conv stem;
    2. splits the result into ``patch_size`` x ``patch_size`` tokens;
    3. runs ``n_layers`` adaLN transformer blocks conditioned on t + label;
    4. returns logits of shape (B, in_channels, H, W, N).

    The head predicts N logits plus N gates per pixel and channel. The
    one-hot input is added back, scaled by ``|1 + gate|``.

    Notes:
        ``input_size`` and ``learn_sigma`` are stored but not otherwise used.
        ``unpatchify`` assumes a square token grid.
    """

    def __init__(
        self,
        in_channels=3,
        N=8,
        input_size=32,
        patch_size=2,
        dim=512,
        n_layers=5,
        n_heads=16,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        class_dropout_prob=0.1,
        num_classes=10,
        learn_sigma=True,
    ):
        super().__init__()
        self.N = N
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        # N logits + N gates for every input channel.
        self.out_channels = N * in_channels * 2
        self.input_size = input_size
        self.patch_size = patch_size
        self.init_conv_seq = nn.Sequential(
            nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
            nn.SiLU(),
            nn.GroupNorm(32, dim // 2),
            nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
            nn.SiLU(),
            nn.GroupNorm(32, dim // 2),
        )
        self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
        nn.init.constant_(self.x_embedder.bias, 0)
        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        self.y_embedder = LabelEmbedder(num_classes, min(dim, 1024), class_dropout_prob)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
        self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 4096)

    def unpatchify(self, x):
        """(B, T, p*p*c) tokens -> (B, c, H, W) image; assumes a square sqrt(T) x sqrt(T) grid."""
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        """(B, C, H, W) feature map -> (B, (H/p)*(W/p), C*p*p) tokens in row-major patch order."""
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def forward(self, x, t, y):
        # freqs_cis is a plain attribute (not a buffer), so move it lazily.
        self.freqs_cis = self.freqs_cis.to(x.device)
        x_onehot = torch.nn.functional.one_hot(x, self.N).float().to(x.device)
        x = (2 * x.float() / (self.N - 1)) - 1.0
        x = self.init_conv_seq(x)
        x = self.patchify(x)
        x = self.x_embedder(x)
        t = self.t_embedder(t)  # (B, min(dim, 1024))
        y = self.y_embedder(y, self.training)  # (B, min(dim, 1024))
        adaln_input = t + y
        for layer in self.layers:
            x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input)
        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x)  # (B, out_channels, H, W)
        # (B, C*2N, H, W) -> (B, C, 2N, H, W) -> (B, C, H, W, 2N).
        # Use permute, not transpose(2, -1): transpose would also swap H and W
        # and misalign logits/gates with x_onehot at the same pixel.
        x, gate = (
            x.reshape(x.shape[0], -1, self.N * 2, *x.shape[2:])
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        ).chunk(2, dim=-1)
        return x + x_onehot * (1 + gate).abs()

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """Classifier-free guidance on the logits.

        The first half of ``x`` is duplicated, and the halves are run with the
        corresponding halves of ``y``. The first output half is treated as
        conditional and the second as unconditional; presumably the second
        half of ``y`` holds the null label (confirm against callers). Axis 1
        of the output has size ``in_channels``, so ``eps`` is the whole output
        and ``rest`` is empty.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    @staticmethod
    def precompute_freqs_cis(dim, end, theta=10000.0):
        """Complex rotary factors of shape (end, dim // 2) for positions 0..end-1."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis
def DiT_Llama_600M_patch2(**kwargs):
    """Build a ``DiT_Llama`` with patch 2, dim 256, 16 layers and 32 heads.

    The name does not match this configuration: dim=256 with 16 layers gives
    far fewer than 600M parameters. The other constructor arguments are
    forwarded via ``kwargs``.
    """
    preset = dict(patch_size=2, dim=256, n_layers=16, n_heads=32)
    return DiT_Llama(**preset, **kwargs)
def DiT_Llama_3B_patch2(**kwargs):
    """Build a ``DiT_Llama`` with patch 2, dim 3072, 32 layers and 32 heads; ``kwargs`` are forwarded."""
    preset = dict(patch_size=2, dim=3072, n_layers=32, n_heads=32)
    return DiT_Llama(**preset, **kwargs)
if __name__ == "__main__":
    # Smoke test: run a forward pass on random data on CPU, and on GPU if available.
    model = DiT_Llama_600M_patch2()
    model.eval()
    x = torch.randint(0, 8, (4, 3, 32, 32))
    t = torch.randint(0, 1000, (4,))
    y = torch.randint(0, 10, (4,))
    out = model(x, t, y)
    print(out)
    # cuda ver: `.cuda()` raises on CPU-only installs, so only run it when a
    # GPU is present.
    if torch.cuda.is_available():
        model = model.cuda()
        x = x.cuda()
        t = t.cuda()
        y = y.cuda()
        out = model(x, t, y)
        print(out)
================================================
FILE: lm.py
================================================
import math
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import get_scheduler
import wandb
from d3pm_runner import D3PM
from dit import DDiT_Llama
class WikiTextDataset(Dataset):
    """English Wikipedia articles as fixed-length token-id tensors.

    Args:
        tokenizer: optional tokenizer called with ``max_length``/``padding``/
            ``truncation``/``return_tensors``. When ``None``, text is encoded
            as raw UTF-8 bytes.
        type_path: unused.
        max_length: every sample is truncated to, or right-padded with 0 to,
            this many tokens.
        debug: when True, ``__len__`` reports the full dataset. Otherwise it
            reports 10% of the rows.
    """

    def __init__(self, tokenizer=None, type_path="train", max_length=512, debug=False):
        self.vernum = 2 if debug else 103
        self.dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        total = len(self.dataset)
        if self.vernum != 103:
            return total
        return int(total * 0.1)

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        if self.tokenizer is None:
            # Byte-level encoding: each UTF-8 byte becomes a token in [0, 255].
            ids = list(text.encode("utf-8"))[: self.max_length]
            ids += [0] * (self.max_length - len(ids))
            input_ids = torch.tensor(ids, dtype=torch.long)
        else:
            encoded = self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = encoded.input_ids.squeeze()
        return {"input_ids": input_ids}
if __name__ == "__main__":
    import os

    wandb.init(project="d3pm_wiki")
    N = 256
    max_length = 256
    num_train_epochs = 5
    d3pm = D3PM(
        DDiT_Llama(N, dim=512, n_layers=6), 1000, num_classes=N, hybrid_loss_coeff=0.0
    ).cuda()
    print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}")
    dataset = WikiTextDataset(max_length=max_length, debug=False)
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=8)
    optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=2e-4)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optim,
        num_warmup_steps=100,
        num_training_steps=num_train_epochs * math.ceil(len(dataloader)),
    )
    # Checkpoints are written here, and torch.save does not create directories.
    os.makedirs("ckpt", exist_ok=True)
    d3pm.train()
    device = "cuda"
    global_step = 0
    for i in range(num_train_epochs):
        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:
            optim.zero_grad()
            x = x["input_ids"].to(device)
            loss, info = d3pm(x)
            loss.backward()
            norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 5.0)
            with torch.no_grad():
                param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            if global_step % 10 == 0:
                # Log plain floats instead of (graph-attached) tensors.
                wandb.log(
                    {
                        "train_loss": loss.item(),
                        "train_grad_norm": norm.item(),
                        "train_param_norm": param_norm.item(),
                    }
                )
            pbar.set_description(
                f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )
            optim.step()
            lr_scheduler.step()
            global_step += 1
            if global_step % 600 == 1:
                d3pm.eval()
                with torch.no_grad():
                    init_noise = torch.randint(0, N, (16, max_length)).cuda()
                    outputs = d3pm.sample_with_image_sequence(
                        init_noise, None, stride=40
                    )
                gen_outputs = []
                total = 0
                # Convert back to sentences: token ids are bytes, decoded as UTF-8.
                for _i in range(16):
                    sent = outputs[-1][_i].cpu().tolist()
                    correctly_parsed = True
                    try:
                        sent = bytes(sent).decode("utf-8")
                    except UnicodeDecodeError:
                        # Not valid UTF-8: fall back to one code point per byte.
                        correctly_parsed = False
                        sent = "".join([chr(b) for b in sent])
                    sent = (
                        f"[{_i}] Sample Correctly parsed: "
                        + str(correctly_parsed)
                        + "\n"
                        + sent
                    )
                    total += 1 if correctly_parsed else 0
                    gen_outputs.append(sent)
                    print(sent)
                # Make a simple HTML page showing the generated outputs.
                html_formatted = "<br>".join(gen_outputs)
                wandb.log(
                    {
                        "generated_text": wandb.Html(html_formatted),
                        "correctly_parsed": total,
                    }
                )
                d3pm.train()
            if global_step % 3000 == 1:
                torch.save(d3pm.state_dict(), f"ckpt/d3pm_wiki_{global_step}.pth")
                print(f"Model saved at {global_step}")
================================================
FILE: lm_deepspeed.py
================================================
import math
import os
import random
import click
import deepspeed
import numpy as np
import torch
from datasets import load_dataset
from deepspeed import get_accelerator
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.utils import logger
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import default_data_collator, get_scheduler
import wandb
from d3pm_runner import D3PM
from dit import DDiT_Llama
class WikiTextDataset(Dataset):
    """Text dataset yielding fixed-length token-id tensors.

    Args:
        tokenizer: optional tokenizer called with ``max_length``/``padding``/
            ``truncation``/``return_tensors``. When ``None``, text is encoded
            as raw UTF-8 bytes.
        type_path: unused.
        max_seq_length: every sample is truncated to, or right-padded with 0
            to, this many tokens.
        debug: load the small WikiText-2 test split instead of English Wikipedia.
    """

    def __init__(
        self, tokenizer=None, type_path="train", max_seq_length=512, debug=False
    ):
        if debug:
            name, config, split = "wikitext", "wikitext-2-raw-v1", "test"
        else:
            name, config, split = "wikimedia/wikipedia", "20231101.en", "train"
        self.dataset = load_dataset(name, config, split=split)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.dataset)

    def _encode_bytes(self, text):
        """UTF-8 bytes of ``text``, truncated or zero-padded to ``max_seq_length``."""
        ids = list(text.encode("utf-8"))[: self.max_seq_length]
        return ids + [0] * (self.max_seq_length - len(ids))

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        if self.tokenizer is None:
            input_ids = torch.tensor(self._encode_bytes(text), dtype=torch.long)
        else:
            encoded = self.tokenizer(
                text,
                max_length=self.max_seq_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            input_ids = encoded.input_ids.squeeze()
        return {"input_ids": input_ids}
def _z3_params_to_fetch(param_list):
return [
p
for p in param_list
if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
]
def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save model weights to ``<save_dir>/pytorch_model.bin`` from rank 0.

    Args:
        model_ema: model (or DeepSpeed engine) to save. If it has a ``module``
            attribute, the wrapped module is saved instead.
        global_rank: global rank of this process; only rank 0 writes the file.
        save_dir: output directory, created (if missing) on every rank.
        zero_stage: DeepSpeed ZeRO stage. For stages other than 3 the module's
            ``state_dict()`` is saved as-is. For stage 3 each parameter from
            ``named_parameters()`` is gathered and copied to CPU, skipping any
            whose name contains ``"lora"``; every rank must call this function
            in that case because the gather is a collective operation.
    """
    is_stage_3 = zero_stage == 3
    os.makedirs(save_dir, exist_ok=True)
    output_model_file = os.path.join(save_dir, "pytorch_model.bin")
    model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema

    if not is_stage_3:
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
        return

    output_state_dict = {}
    for name, param in model_to_save.named_parameters():
        # Only rank 0 keeps a CPU copy; the other ranks merely take part in the
        # gather, so copying the full tensor there would just waste memory.
        keep = global_rank == 0 and "lora" not in name
        if hasattr(param, "ds_id"):
            # Every rank must enter the gather context (collective call).
            with deepspeed.zero.GatheredParameters(
                _z3_params_to_fetch([param]), enabled=is_stage_3
            ):
                if keep:
                    output_state_dict[name] = param.data.cpu()
        elif keep:
            output_state_dict[name] = param.detach().cpu()

    if global_rank == 0:
        torch.save(output_state_dict, output_model_file)
    del output_state_dict
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy, PyTorch CPU and all CUDA generators."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def _decode_byte_sample(token_ids):
    """Decode byte-level token ids into text.

    Returns ``(text, ok)``. ``ok`` is True when the ids form valid UTF-8;
    otherwise each id is mapped with ``chr`` so the sample can still be shown.
    """
    try:
        return bytes(token_ids).decode("utf-8"), True
    except (UnicodeDecodeError, ValueError):
        # ValueError: an id outside 0..255 cannot be packed into ``bytes``.
        return "".join(chr(i) for i in token_ids), False


def _generate_and_log_samples(
    model, num_classes, seq_len, device, global_rank, num_samples=16
):
    """Sample ``num_samples`` sequences from ``model`` and report them on rank 0.

    Every rank runs the sampling (with ZeRO-3 the parameters are partitioned,
    so forward passes need all ranks); only rank 0 decodes, prints and logs
    the samples plus the count that decode as valid UTF-8 to wandb.
    Leaves ``model`` in eval mode; the caller restores training mode.
    """
    model.eval()
    with torch.no_grad():
        init_noise = torch.randint(0, num_classes, (num_samples, seq_len)).to(device)
        outputs = model.sample_with_image_sequence(init_noise, None, stride=40)
    if global_rank != 0:
        return

    gen_outputs = []
    total = 0
    for idx, token_ids in enumerate(outputs[-1].cpu().tolist()):
        text, correctly_parsed = _decode_byte_sample(token_ids)
        total += int(correctly_parsed)
        sample = f"[{idx}] Sample Correctly parsed: {correctly_parsed}\n{text}"
        gen_outputs.append(sample)
        print(sample)

    # Generated text may contain '<', '&', ... so escape it before embedding
    # it in HTML, and turn newlines into line breaks so they render.
    html_formatted = "<br><br>".join(
        html.escape(sample).replace("\n", "<br>") for sample in gen_outputs
    )
    wandb.log(
        {
            "generated_text": wandb.Html(html_formatted),
            "correctly_parsed": total,
        }
    )


@click.command()
@click.option("--local_rank", default=-1, help="Local rank")
@click.option("--max_seq_length", default=256, help="Max sequence length")
@click.option("--num_train_epochs", default=5, help="Number of training epochs")
@click.option("--learning_rate", default=1e-4, help="Learning rate")
@click.option("--offload", default=False, help="Offload")
@click.option("--train_batch_size", default=1024, help="Train batch size")
@click.option(
    "--per_device_train_batch_size", default=64, help="Per device train batch size"
)
@click.option("--zero_stage", default=2, help="Zero stage")
@click.option("--seed", default=42, help="Seed")
@click.option("--run_name", default=None, help="Run name")
def main(
    local_rank,
    max_seq_length=256,
    num_train_epochs=5,
    learning_rate=1e-4,
    offload=False,
    train_batch_size=512,
    per_device_train_batch_size=64,
    zero_stage=2,
    seed=42,
    run_name=None,
):
    """Train a byte-level D3PM on Wikipedia text with DeepSpeed.

    Samples are generated and logged to wandb every 600 steps, and weights are
    written to ./ckpt/pytorch_model.bin every 3000 steps (starting at step 1).
    """
    # first, set the seed
    set_seed(seed)
    # Disable the flash and memory-efficient SDPA kernels.
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)

    if run_name is None:
        run_name = f"LR:{learning_rate}_max_seq_length:{max_seq_length}_num_train_epochs:{num_train_epochs}_offload:{offload}"

    if local_rank != -1:
        get_accelerator().set_device(local_rank)
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    deepspeed.init_distributed()

    offload_device = "cpu" if offload else "none"
    ds_config = {
        "train_micro_batch_size_per_gpu": per_device_train_batch_size,
        "train_batch_size": train_batch_size,
        "zero_optimization": {
            "stage": zero_stage,
            "offload_param": {"device": offload_device},
            "offload_optimizer": {"device": offload_device},
            "stage3_param_persistence_threshold": 1e4,
            "stage3_max_live_parameters": 3e7,
            "stage3_prefetch_bucket_size": 3e7,
            "memory_efficient_linear": False,
        },
        "bfloat16": {"enabled": True},
        "gradient_clipping": 1.0,
    }

    torch.distributed.barrier()
    global_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    ##### DEFINE model, dataset, sampler, dataloader, optim, scheduler
    N = 256  # one class per byte value
    with deepspeed.zero.Init(enabled=(zero_stage == 3)):
        d3pm = D3PM(
            DDiT_Llama(N, dim=768, n_layers=8),
            1000,
            num_classes=N,
            hybrid_loss_coeff=0.0,
        ).cuda()

    # Under ZeRO-3 the parameters are partitioned and ``numel()`` no longer
    # reports their full size; DeepSpeed records that in ``ds_numel``.
    total_params = sum(getattr(p, "ds_numel", p.numel()) for p in d3pm.parameters())
    size_in_bytes = total_params * 4  # assuming 4 bytes (fp32) per parameter
    size_in_gb = size_in_bytes / (1024**3)
    logger.info(
        f"Model Size: {size_in_bytes}, {size_in_gb} GB, Total Param Count: {total_params / 1e6} M"
    )

    dataset = WikiTextDataset(max_seq_length=max_seq_length, debug=False)
    train_sampler = (
        RandomSampler(dataset)
        if local_rank == -1
        else DistributedSampler(dataset, seed=seed)
    )
    dataloader = DataLoader(
        dataset,
        collate_fn=default_data_collator,
        sampler=train_sampler,
        batch_size=per_device_train_batch_size,
    )

    optimizer = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=learning_rate)
    # DeepSpeed steps the LR scheduler once per optimizer step, i.e. once every
    # gradient_accumulation_steps micro-batches, so the schedule length must be
    # counted in optimizer steps rather than dataloader batches.
    grad_accum_steps = max(
        1, train_batch_size // (per_device_train_batch_size * world_size)
    )
    steps_per_epoch = math.ceil(len(dataloader) / grad_accum_steps)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=num_train_epochs * steps_per_epoch,
    )

    d3pm.train()
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=d3pm, config=ds_config, lr_scheduler=lr_scheduler, optimizer=optimizer
    )

    global_step = 0

    ##### actual training loop
    if global_rank == 0:
        wandb.init(
            project="d3pm_wiki",
            name=run_name,
            config={
                "N": N,
                "max_seq_length": max_seq_length,
                "num_train_epochs": num_train_epochs,
                "learning_rate": learning_rate,
                "offload": offload,
                "train_batch_size": train_batch_size,
                "per_device_train_batch_size": per_device_train_batch_size,
                "zero_stage": zero_stage,
                "seed": seed,
            },
        )

    for epoch in range(num_train_epochs):
        # Without set_epoch, DistributedSampler repeats the same order each epoch.
        if isinstance(train_sampler, DistributedSampler):
            train_sampler.set_epoch(epoch)
        pbar = tqdm(dataloader, disable=global_rank != 0)
        for batch in pbar:
            x = batch["input_ids"].to(model_engine.device)
            loss, info = model_engine(x)
            model_engine.backward(loss)
            model_engine.step()
            get_accelerator().empty_cache()

            norm = model_engine.get_global_grad_norm()
            if global_step % 10 == 0 and global_rank == 0:
                wandb.log({"train_loss": loss.item(), "train_grad_norm": norm})
            pbar.set_description(
                f"norm: {norm}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}"
            )

            global_step += 1

            if global_step % 600 == 1:
                _generate_and_log_samples(
                    d3pm, N, max_seq_length, model_engine.device, global_rank
                )
                model_engine.train()

            if global_step % 3000 == 1:
                # Collective under ZeRO-3: every rank must call it.
                save_zero_three_model(
                    model_engine, global_rank, "./ckpt", zero_stage=zero_stage
                )
                if global_rank == 0:
                    print(f"Model saved at {global_step}")
if __name__ == "__main__":
    # Script entry point: the click command parses the CLI options and runs main().
    main()
================================================
FILE: readme.md
================================================
# Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch
**Special thanks to [fal.ai](https://fal.ai/) for the compute resources for this project.**
This is a minimal (400 LOC) but fully faithful implementation of D3PM ([Structured Denoising Diffusion Models in Discrete State-Spaces](https://arxiv.org/abs/2107.03006)) in PyTorch.
I have tried to keep the code as simple as possible, with plenty of comments and explanations (which are somewhat lacking in the original JAX implementation), so that it is easy to understand. As far as I know, this is the first faithful reimplementation of D3PM in PyTorch (please correct me if I am wrong). Of course, this implementation is heavily based on the [official implementation](https://github.com/google-research/google-research/tree/master/d3pm/images).
Differences between this implementation and the official implementation:
* This one has conditional sampling, so as you can see, generations are class-conditioned.
* This one uses a rather different, simpler model architecture.
* This one simplifies the official implementation a great deal, so it is only 400 LOC.
* This one does not use truncated logistic reparameterization, but you can use that if you wish.
* It only supports uniform noise with an inverse-linear beta schedule, but you can change that with a couple of lines of code as well.
## Usage
The following is a completely self-contained example:
```bash
python d3pm_runner.py
```
The following uses `dit.py` and trains on the CIFAR-10 dataset:
```bash
python d3pm_runner_cifar10.py
```
## Requirements
Install torch, torchvision, pillow, and tqdm:
```bash
pip install torch torchvision pillow tqdm
```
## Citation
This implementation:
```bibtex
@misc{d3pm_pytorch,
author={Simo Ryu},
title={Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch},
year={2024},
howpublished={\url{https://github.com/cloneofsimo/d3pm}}
}
```
Original Paper:
```bibtex
@article{austin2021structured,
title={Structured denoising diffusion models in discrete state-spaces},
author={Austin, Jacob and Johnson, Daniel D and Ho, Jonathan and Tarlow, Daniel and Van Den Berg, Rianne},
journal={Advances in Neural Information Processing Systems},
volume={34},
pages={17981--17993},
year={2021}
}
```
================================================
FILE: run_multigpu.sh
================================================
# Count the GPUs listed by nvidia-smi (one line per device) and export the
# count as WORLD_SIZE for the launched processes.
WORLD_SIZE=$(nvidia-smi -L | wc -l)
export WORLD_SIZE

# Launch lm_deepspeed.py through the DeepSpeed launcher, one process per GPU.
deepspeed --num_gpus $WORLD_SIZE lm_deepspeed.py --learning_rate 1e-4
================================================
FILE: test.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"def get_logits_from_logistic_pars(loc, log_scale, num_classes = 10):\n",
" loc = loc.unsqueeze(-1)\n",
" log_scale = log_scale.unsqueeze(-1)\n",
"\n",
" inv_scale = (-log_scale + 2.0).exp()\n",
" \n",
" bin_width = 2.0 / (num_classes - 1)\n",
" bin_centers = torch.linspace(-1.0, 1.0, num_classes).to(loc.device)\n",
" bin_centers = bin_centers.reshape((*loc.shape, num_classes))\n",
" bin_centers = bin_centers - loc\n",
" log_cdf_min = -torch.log1p((-inv_scale * (bin_centers - 0.5 * bin_width)).exp())\n",
" log_cdf_plus = -torch.log1p((-inv_scale * (bin_centers + 0.5 * bin_width)).exp())\n",
" logits = log_minus_exp(log_cdf_plus, log_cdf_min)\n",
" return logits\n",
"\n",
"def log_minus_exp(a, b, epsilon=1.e-6):\n",
" return a + torch.log1p(-torch.exp(b - a) + epsilon)\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"loc = torch.randn(1, 1, 1, 1).clip(-1, 1)\n",
"log_scale = torch.randn(1, 1, 1, 1)\n",
"\n",
"num_classes = 10\n",
"logits = get_logits_from_logistic_pars(loc, log_scale, num_classes)\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.4609, -1.4102, -1.8315, -2.5864, -3.5119, -4.5085, -5.5319, -6.5647,\n",
" -7.6002, -8.6343]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"logits[0,0,0,0,:]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGiCAYAAADJO+2bAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABQWUlEQVR4nO3deVxU5f4H8M+wzLAIw76og4IbbriLoKWVaaalVlZmKmau3MzsVnhvZrahV+t2f+aWuWUZmblUZqallor7rqCiKIuACDIDCAPMPL8/zClSEXTgOQOf9+t1Xq/m8JzhezjCfDrnWVRCCAEiIiIiBbKTXQARERHR7TCoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYkkPKvn5+Zg8eTIaNWoEZ2dnREZGYv/+/bLLIiIiIgWQHlRefPFFbNmyBStXrsTx48fRp08f9O7dG+np6bJLIyIiIslUMhclLCoqgpubGzZs2ID+/ftb9nfq1An9+vXDe++9J6s0IiIiUgAHmd+8rKwMJpMJTk5O5fY7Oztj586dtzzGaDTCaDRaXpvNZuTm5sLb2xsqlapa6yUiIiLrEEIgPz8f9evXh51dBQ94hGQRERGiZ8+eIj09XZSVlYmVK1cKOzs70bx581u2nz59ugDAjRs3bty4casFW2pqaoU5QeqjHwA4d+4cXnjhBfz222+wt7dHx44d0bx5cxw8eBAJCQk3tf/7HRW9Xo+goCCkpqbC3d29JksnIiKiu2QwGKDT6ZCXlwetVnvbdlIf/QBAkyZNsGPHDhQWFsJgMCAwMBDPPPMMQkJCbtleo9FAo9HctN/d3Z1BhYiIyMbcqduG9FE/N7i6uiIwMBBXr17F5s2bMXDgQNklERERkWTS76hs3rwZQgi0aNECSUlJeO211xAaGopRo0bJLo2IiIgkk35HRa/XIzo6GqGhoRgxYgR69OiBzZs3w9HRUXZpREREJJn0zrT3ymAwQKvVQq/Xs48KERGRjajs57f0OypEREREt8OgQkRERIrFoEJERESKxaBCREREisWgQkRERIrFoEJERESKxaBCREREisWgQkRERIolNaiYTCZMmzYNwcHBcHZ2RpMmTfDuu+/CxuegIyIiIiuRutbPrFmzsGDBAqxYsQKtW7fGgQMHMGrUKGi1WkyaNElmaURERKQAUoPK7t27MXDgQPTv3x8A0LhxY3z11VfYt2+fzLKIiIhIIaQ++omMjMQvv/yCM2fOAACOHj2KnTt3ol+/frc9xmg0wmAwlNuIiIiodpJ6RyUmJgYGgwGhoaGwt7eHyWTC+++/j2HDht32mNjYWMyYMaMGqyQiIiJZpN5RWb16Nb788kusWrUKhw4dwooVKzBnzhysWLHitsdMnToVer3esqWmptZgxURERFSTVELiEBudToeYmBhER0db9r333nv44osvkJiYWKn3qOwy0URERKQclf38lnpH5dq1a7CzK1+Cvb09zGazpIqIiIhISaT2UXnsscfw/vvvIygoCK1bt8bhw4fx0Ucf4YUXXpBZFhERESmE1Ec/+fn5mDZtGtatW4fLly+jfv36GDp0KN566y2o1epKvQcf/RAREdmeyn5+Sw0q1sCgQkREZHtsoo8KERERUUUYVIiIiEixGFSIiIhIsRhUiIiISLEYVIiIiEixGFSIiIhIsRhUiIiISLEYVIiIiEixGFSIiIhIsaQHlcaNG0OlUt20/XVFZSIiIqqbpC5KCAD79++HyWSyvD5x4gQefvhhDBkyRGJVREREpATSg4qvr2+51zNnzkSTJk3Qs2dPSRURERGRUkgPKn9VUlKCL774AlOmTIFKpbplG6PRCKPRaHltMBhqqjwiIiKqYdL7qPzV+vXrkZeXh6ioqNu2iY2NhVartWw6na7mCiQiIqIapRJCCNlF3NC3b1+o1Wp8//33t2
1zqzsqOp3ujstEExERkXIYDAZotdo7fn4r5tHPxYsXsXXrVqxdu7bCdhqNBhqNpoaqIiIiIpkU8+hn2bJl8PPzQ//+/WWXQkRERAqhiKBiNpuxbNkyjBw5Eg4OirnJQ0RERJIpIqhs3boVKSkpeOGFF2SXQkRERAqiiNsXffr0gYL69BIREZFCKOKOChEREdGtMKgQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYkkPKunp6Xj++efh7e0NZ2dntG3bFgcOHJBdFhERESmA1Jlpr169iu7du+OBBx7Apk2b4Ovri7Nnz8LT01NmWURERKQQUoPKrFmzoNPpsGzZMsu+4OBgiRURERGRkkh99PPdd9+hc+fOGDJkCPz8/NChQwcsXry4wmOMRiMMBkO5jYiIiGonqUHl/PnzWLBgAZo1a4bNmzdjwoQJmDRpElasWHHbY2JjY6HVai2bTqerwYqJiIioJqmExGWL1Wo1OnfujN27d1v2TZo0Cfv370d8fPwtjzEajTAajZbXBoMBOp0Oer0e7u7u1V4zERER3TuDwQCtVnvHz2+pd1QCAwPRqlWrcvtatmyJlJSU2x6j0Wjg7u5ebiMiIqLaSWpQ6d69O06fPl1u35kzZ9CoUSNJFREREZGSSA0qr7zyCvbs2YMPPvgASUlJWLVqFT799FNER0fLLIuIiIgUQmpQ6dKlC9atW4evvvoKbdq0wbvvvouPP/4Yw4YNk1kWERERKYTUzrTWUNnOOERERKQcNtGZloiIiKgiDCpERESkWAwqREREpFgMKkRERKRYDCpERESkWAwqREREpFgMKkRERKRYDCpERESkWAwqREREpFjSg8rbb78NlUpVbgsNDZVdFhERESmAg+wCAKB169bYunWr5bWDgyLKIiIiIskUkQgcHBwQEBBQqbZGoxFGo9Hy2mAwVFdZREREJJn0Rz8AcPbsWdSvXx8hISEYNmwYUlJSbts2NjYWWq3Wsul0uhqslIiIiGqS9NWTN23ahIKCArRo0QIZGRmYMWMG0tPTceLECbi5ud3U/lZ3VHQ6HVdPJiIisiGVXT1ZelD5u7y8PDRq1AgfffQRRo8efcf2lT1RIiIiUo7Kfn4r4tHPX3l4eKB58+ZISkqSXQoRERFJprigUlBQgHPnziEwMFB2KURERCSZ9KDyz3/+Ezt27MCFCxewe/duDB48GPb29hg6dKjs0oiIiEgy6cOT09LSMHToUOTk5MDX1xc9evTAnj174OvrK7s0IiIikkx6UImLi5NdAhERESmU9Ec/RERERLfDoEJERESKxaBCREREisWgQkRERIrFoEJERESKxaBCREREisWgQkRERIrFoEJERESKpaigMnPmTKhUKkyePFl2KURERKQAigkq+/fvx6JFixAWFia7FCIiIlIIRQSVgoICDBs2DIsXL4anp6fscoiIiEghFBFUoqOj0b9/f/Tu3fuObY1GIwwGQ7mNiIiIaidFLEp46NAh7N+/v1LtY2NjMWPGjGquioiIiJRA6h2V1NRUvPzyy/jyyy/h5ORUqWOmTp0KvV5v2VJTU6u5SiIiIpJFJYQQsr75+vXrMXjwYNjb21v2mUwmqFQq2NnZwWg0lvvarRgMBmi1Wuj1eri7u1d3yURERGQFlf38lvro56GHHsLx48fL7Rs1ahRCQ0Pxxhtv3DGkEBERUe0mNai4ubmhTZs25fa5urrC29v7pv1ERERU9yhi1A8RERHRrUgf9fN327dvl10CERERKQTvqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWJJDyoLFixAWFgY3N3d4e7ujoiICGzatEl2WURERKQA0oNKw4YNMXPmTBw8eBAHDhzAgw8+iIEDB+LkyZOySyMiIiLJVEIIIbuIv/Py8sLs2bMxevTom75mNBphNB
otrw0GA3Q6HfR6Pdzd3WuyTCIiIrpLBoMBWq32jp/f0u+o/JXJZEJcXBwKCwsRERFxyzaxsbHQarWWTafT1XCVREREVFMUcUfl+PHjiIiIQHFxMerVq4dVq1bh0UcfvWVb3lEhIiKyfZW9o+JQgzXdVosWLXDkyBHo9XqsWbMGI0eOxI4dO9CqVaub2mo0Gmg0GglVEhERUU1TxB2Vv+vduzeaNGmCRYsW3bFtZRMZERERKYdN9lG5wWw2l3u8Q0RERHWT9Ec/U6dORb9+/RAUFIT8/HysWrUK27dvx+bNm2WXRkRERJJJDyqXL1/GiBEjkJGRAa1Wi7CwMGzevBkPP/yw7NKIiIhIMulBZcmSJbJLICIiIoVSZB8VIiIiIoBBhYiIiBSMQYWIiIgUi0GFiIiIFItBhYiIiBSLQYWIiIgUi0GFiIiIFItBhYiIiBRLelCJjY1Fly5d4ObmBj8/PwwaNAinT5+WXRYREREpgPSgsmPHDkRHR2PPnj3YsmULSktL0adPHxQWFsoujYiIiCRTCSGE7CL+Kjs7G35+ftixYwfuv//+O7av7DLRREREpByV/fyWvtbP3+n1egCAl5fXLb9uNBphNBotrw0GQ43URURERDVP+qOfvzKbzZg8eTK6d++ONm3a3LJNbGwstFqtZdPpdDVcJREREdUURT36mTBhAjZt2oSdO3eiYcOGt2xzqzsqOp2Oj36IiIhsiM09+vnHP/6BH374Ab/99tttQwoAaDQaaDSaGqyMiIiIZJEeVIQQeOmll7Bu3Tps374dwcHBsksiIiIihZAeVKKjo7Fq1Sps2LABbm5uyMzMBABotVo4OztLro6IiIhkkt5HRaVS3XL/smXLEBUVdcfjOTyZiIjI9thMHxUF9eUlIiIihVHU8GQiIiKiv2JQISIiIsViUCEiIiLFYlAhIiIixWJQISIiIsViUCEiIiLFYlAhIiIixWJQISIiIsViUCEiIiLFkh5UfvvtNzz22GOoX78+VCoV1q9fL7skIiIiUgjpQaWwsBDt2rXDvHnzZJdCRERECiN9rZ9+/fqhX79+lW5vNBphNBotrw0GQ3WURURERAog/Y5KVcXGxkKr1Vo2nU4nuyQiIiKqJjYXVKZOnQq9Xm/ZUlNTZZdERERE1UT6o5+q0mg00Gg0sssgIiKiGmBzd1SIiIio7mBQISIiIsWS/uinoKAASUlJltfJyck4cuQIvLy8EBQUJLEyIiIikk16UDlw4AAeeOABy+spU6YAAEaOHInly5dLqoqIiIiUQHpQ6dWrF4QQsssgIiIiBWIfFSIiIlIsBhUiIiJSLAYVIiIiUiwGFSIiIlIsBhUiIiJSLAYVIiIiUiwGFSIiIlIsBhUiIiJSLEUElXnz5qFx48ZwcnJCeHg49u3bJ7skIiIiUgDpQeXrr7/GlClTMH36dBw6dAjt2rVD3759cfnyZdmlERERkWTSg8pHH32EMWPGYNSoUWjVqhUWLlwIFxcXLF26VHZpREREJJnUoFJSUoKDBw+id+/eln12dnbo3bs34uPjb3mM0WiEwWAotxEREVHtJDWoXLlyBSaTCf7+/uX2+/v7IzMz85bHxMbGQqvVWjadTlcTpRIREZEE0h/9VNXUqVOh1+stW2pqquySiIiIqJo4yPzmPj4+sLe3R1ZWVrn9WVlZCAgIuOUxGo0GGo2mJsojIiIiyaTeUVGr1ejUqRN++eUXyz6z2YxffvkFEREREisjIiIiJZB6RwUApkyZgpEjR6Jz587o2rUrPv74YxQWFmLUqFGySyMiIiLJpAeVZ555BtnZ2XjrrbeQmZmJ9u3b46effrqpgy0RERHVPSohhJBdxL0wGAzQarXQ6/Vwd3eXXQ4RERFVQmU/v21u1A8RERHVHQwqREREpFgMKkRERKRYDCpERESkWAwqREREpFgMKkRERKRY0udRISICgJIyM/RFpZbN/LeZE+ztVNA6O1o2R3v+fxZRXcCgQkQ1Qg
iBLIMRJy/pkXylECm5165vOdeQaSjGtRJTld7PVW2PAK0Tgrxc0MjbFUFeLgj2cUXrBu7wc3OqprMgoprGoEJE1aK41IT9F3Jx8OJVHE/T41i6Htn5xgqPUakAdydHuDs7wNGu/B2TEpMZhqJSGIrLAACFJSacyy7EuexCANnl2ga4O6FtQy3CGmjRqZEnOjX2hMbB3qrnR0Q1Q2pQef/997Fx40YcOXIEarUaeXl5MsshontgNgucvGTA70nZ2JV0BfsvXEVJmblcG3s7FZr51UMT33oI8na5fjfEywX1PZzh4eIINydH2NupKvw+JrOAoagUeUWlSL9ahIu5hZY7M0mXC5CUXYBMQzEyTxVjy6nrK7M7Odqha7A3ejT1xn3NfBEa4AaVquLvQ0TKIHUK/enTp8PDwwNpaWlYsmTJXQUVTqFPJI/ZLHA49So2HsvEphMZyNAXl/t6oNYJ3UK80a6hFm0beqBVoDuc1dV7Z6PQWIZTGQYcS9PjaGoe9pzPweW/3cnReTnj0TaBeLRtIMIaahlaiCSo7Oe3Itb6Wb58OSZPnlypoGI0GmE0/vlHx2AwQKfTMagQ1aAzWfn4en8qNh7LQKbhz3DiqrZHRBMf3NfMBz2a+SDEx1V6CBBC4OzlAvx+9gp2ns1G/PkcFJf+eaenoaczBoTVx9OdGyLEt57ESonqlsoGFZvroxIbG4sZM2bILoOozik0lmHjsQzE7U/BoZQ8y/56Ggf0bumHR9sG4v7mvnByVFZfEJVKheb+bmju74bRPYJxraQM209nY+PxDPyacBlpV4uwcMc5LNxxDuHBXni2qw792gQq7jyI6ireUSGiCl3MKcTSncn49lA6CozXO7La26nwUKgfnurUUJHhpLKKSkzYdvoy1hxMw/bTl2H+46+hu5MDnumiQ1T3YDTwcJZbJFEtJe2OSkxMDGbNmlVhm4SEBISGht7V+2s0Gmg0mrs6logq7+DFXCz+LRmbT2Xixv/ONPZ2wTNdgvBkpwa1Ygiws9oej7a93lflUl4R1hxMw9f7U5GeV4TFvydj6a4L6N82EC/eF4ywhh6yyyWqk6x+RyU7Oxs5OTkVtgkJCYFarba8rsodlb9jZ1oi6xFCYMeZbPzfL2fLPd7p1cIXo3sEo0dTH+l9Tqqb2Syw/cxlfPZ7Mnaf+/NvWbcQL0x6qBkiQrxr/c+AqCZIu6Pi6+sLX19fa78tEVWjGwHl461ncSQ1DwCgtrfD4A4NMPq+YDT3d5NbYA2ys1PhwVB/PBjqjxPpeizZmYzvj17CnvO52HN+L7oGe2Fy72aIbOIju1SiOkFqZ9qUlBTk5uYiJSUFJpMJR44cAQA0bdoU9eqx9z1RTdh59grm/HzaElCcHO0wvFsjjLk/pFY83rkXbRpo8d9n2uP1R1pg4fZz+GpfKvYl5+K5xXsRHuyFf/ZtgS6NvWSXSVSrSe1MGxUVhRUrVty0f9u2bejVq1el3oOPfojuTkKGAbGbEvHbmeuzut4IKGPvbwJfN/YDu5UMfREWbD+HuH2pKDFdH+Lcp5U/3ugXiiYc2kxUJTY1j8q9YFAhqppMfTE+/Pk01hxKgxCAo70Kw8IbIfqBpgwolZShL8LcX5Pw9f5UmMwC9nYqDO2qw8sPNefPkKiSGFSIqJziUhMW/3Ye87YnWSY86982EK8/0gKNvF0lV2ebki7nY+amRGxNuAzg+oR3L/duhqjIYKgduLozUUUYVIjIYuupLLzzwymk5F4DAHRp7Il/PdoSHYI8JVdWO+w5n4MPfkzAsTQ9AKCJrytmPN4GPZqxwy3R7TCoEBEuXCnEOz+cwq+J1/+PP8DdCf/u3xIDwgI5xNbKzGaBNYfSMGtTInIKSwAA/doE4M0BrThpHNEtMKgQ1WGlJjM+/e08/vfLWZSUmeFor8LoHiF46cGmcNXY3MoZNkVfVIr/bjmDz+MvwCwAF7U9Xu3TAlGRje+4MjRRXcKgQlRHHUq5iqnfHs
fprHwAQI+mPpgxsDVHpdSwhAwD3tpwAvsvXAUAhDXU4oPBbdGmgVZyZUTKwKBCVMcUGMvwn58SsXLPRQgBeLmq8daAVhjYvj4f80hiNgvE7U9F7KYE5BeXwd5OhRd7BGNy7+ZwVtvm+khE1sKgQlSH7Eq6gtfXHEN6XhEA4MmODfHv/i3h5aq+w5FUEy4bijHj+1PYeDwDABDs44o5Q8LQqREni6O6i0GFqA4oMJYh9scEfLk3BQDQ0NMZM58I42gThdp6Kgv/Xn8cWQYjVCpgdPdg/LNvC5tdfZroXlT281vaQP8LFy5g9OjRCA4OhrOzM5o0aYLp06ejpKREVklENmX3uSt45OPfLCFleLdG2Dz5foYUBevdyh8/T+6JJzs2hBDAZzuT8ej/fsfBi1dll0akWNK6/ycmJsJsNmPRokVo2rQpTpw4gTFjxqCwsBBz5syRVRaR4hWXmjB782ks2ZkM4PpdlP88GYbIpgwotkDr4ogPn26H/mEBiPn2OM5fKcSQhbsR/UBTTHqoGRztOVEc0V8p6tHP7NmzsWDBApw/f77Sx/DRD9Ulpy4ZMPnrwziTVQAAeC48CP96tCXqccixTdJfK8Xb35/EusPpAIC2fyyC2NSPI7So9lP8o59b0ev18PKquHOZ0WiEwWAotxHVdiazwMId5zBw3k6cySqATz01lkZ1xgeD2zKk2DCtiyP++0x7zHuuI7TOjjierseAub/j8/gLUND/QxJJpZigkpSUhLlz52LcuHEVtouNjYVWq7VsOp2uhiokkiNTX4xhn+3BzE2JKDUJPNzKH5sn348HQ/1ll0ZW0j8sEJsn34/7mvmguNSMtzacxOgVB5BTYJRdGpF0Vn/0ExMTg1mzZlXYJiEhAaGhoZbX6enp6NmzJ3r16oXPPvuswmONRiOMxj9/eQ0GA3Q6HR/9UK205VQWXltzFHnXSuGitsdbA1rhmS46zotSS5nNAiviLyB2UyJKyszwddPgv0+3ZwdpqpWkDU/Ozs5GTk5OhW1CQkKgVl+f3+HSpUvo1asXunXrhuXLl8POrmo3edhHhWqj4lITYn9MwIr4iwCANg3c8X/PdkAIZ5etExIyDJj01WGcvVwAlQoYd38TvNqnOTvaUq1iE/OopKen44EHHkCnTp3wxRdfwN6+6nMJMKhQbZN0uQD/WHUIiZnXp8Afc9/1uTY0Dpxroy4pKjHhvY2nLMPP2zXU4pPnOkLn5SK5MiLrUHxQSU9PR69evdCoUSOsWLGiXEgJCAio9PswqFBtsvZQGt5cfwLXSkzwqafGnCHt0KuFn+yySKKfTmTgjW+PQ19UCjcnB8x+KgyPtAmUXRbRPVN8UFm+fDlGjRp1y69VpSQGFaoNikpMeGvDCXxzMA0AENnEGx8/2x5+bk6SKyMlSM8rwkurDuFQSh4AYGREI/yrf0veZSObpvigYi0MKmTrki7nY+KXh3Am63p/hMkPNcc/HmwKezt2mKU/lZrMmPPzaSzacX2eqTYN3DHvuY5o5O0quTKiu2OT86gQ1TUbjqTjsbm7cCarAL5uGnz5Yjhe7t2MIYVu4mhvh6n9WmJpVGd4uDjiRLoBA+buxM8nM2WXRlStGFSIJDCWmTBt/Qm8HHcERaUmRDbxxo+T7kNkEw5DpYo9GOqPHyfdh45BHsgvLsPYlQcR+2MCykxm2aURVQsGFaIalnb1Gp5eGI+Ve64PPX7pwaZYOTocvm4ayZWRrajv4Yy4sRF4oXswAGDRb+fx3Gd7cdlQLLkyIutjUCGqQTvOZGPA3J04mqaH1tkRy6K64NU+Lfioh6pM7WCHtx5rhfnDOqKexgH7knPx6P/txN7zFc9jRWRrGFSIaoDZLDD3l7OIWrYPeddK0a6hFhsn9cADoRx6TPfm0baB+O4f3dHC3w1XCox47rO9+Oz381wriGoNBhWiaqYvKsWYzw/gwy1nIMT1FY9Xj49AQ09O3EXWEeJbD+uiIzGwfX2YzALvbU
zAS18dRqGxTHZpRPeMy64SVaPETAPGrTyIiznXoHaww3uD2uDpzlxIk6zPRe2Aj59pjw46D7y3MQE/HMvAmax8LHy+E5deIJvGOypE1eT7o5cweN5uXMy5hoaezlg7IZIhhaqVSqVCVPdgxI3tBj83Dc5kFWDgJ7vwS0KW7NKI7prUoPL4448jKCgITk5OCAwMxPDhw3Hp0iWZJRHdszKTGR/8eP3We1GpCfc188H3/+iBNg20skujOqJzYy/8MKkHujT2RL6xDKNXHMDHW8/AbGa/FbI9UoPKAw88gNWrV+P06dP49ttvce7cOTz11FMySyK6J7mFJRi5bB8+/e367KETejXB8lFd4emqllwZ1TV+bk748sVuGBHRCADw8dazGLvyIAzFpZIrI6oaRU2h/91332HQoEEwGo1wdHSs1DGcQp+U4uQlPcZ+fhDpeUVwUdtj9lPt0D+Mi8eRfN8cSMW/159ASZkZIT6u+HREZzT1Y78VksvmptDPzc3Fl19+icjIyApDitFohMFgKLcRyfb90Ut4csFupOcVoZG3C9ZN7M6QQooxpLMOa8ZHIFDrhPNXCjF4HvutkO2QHlTeeOMNuLq6wtvbGykpKdiwYUOF7WNjY6HVai2bTsfOiSSPySww66dEvPTVYRSXmtGzuS++i+6BFgFusksjKiesoQe+f6kHujb2Qr6xDC9+fgCf/HqW862Q4ln90U9MTAxmzZpVYZuEhASEhoYCAK5cuYLc3FxcvHgRM2bMgFarxQ8//ACV6tYzdRqNRhiNRstrg8EAnU7HRz9U4/RFpXg57jC2n84GAIzrGYLX+4ZylllStJIyM9794ZRlCYdH2wZg9lPt4KrhbBVUsyr76MfqQSU7Oxs5ORVP4RwSEgK1+ubOhWlpadDpdNi9ezciIiIq9f3YR4VkOJddgDErDuD8lUI4Odph1pNhGNi+geyyiCotbl8Kpm04gVKTQGiAGxaP6AydFychpJpT2c9vq0doX19f+Pr63tWxZvP11T//eseESGm2nb6MSasOI99YhgYezlg0vBOHHpPNebZrEJr518O4lYeQmJmPgfN2Yf6wjugW4i27NKJypI362bt3L/bv348ePXrA09MT586dw7Rp05CVlYWTJ09Co6ncSrK8o0I1RQiBT387j5k/JUIIoGtjL8x/viN86nHVY7JdGfoijP38II6n6+Fgp8L0x1tjeLdGssuiOkDxo35cXFywdu1aPPTQQ2jRogVGjx6NsLAw7Nixo9IhhaimFJea8MrXRxC76XpIGdo1CF+8GM6QQjYvUOuMb8ZH4PF29VFmFpi2/gT+te44SsrMsksjAqCweVTuBu+oUHXLMhRj7OcHcDRND3s7Fd5+rBWe79both2+iWyREAILd5zHfzZfD+PhwV5Y8HwneHGyQqomir+jQmQLjqXl4fFPduJomh4eLo5YOborhkc0ZkihWkelUmFCryb4bERn1NM4YG9yLgbO24nTmfmyS6M6jkGF6Da+O3oJQxbGI8tgRDO/etgQ3R2RTXxkl0VUrR5q6Y+1EyMR5OWC1NwiPDF/F7ac4uRwJA+DCtHfmM0CH/58GpO+OgxjmRkPhfph7cRINPJ2lV0aUY1o7u+GDdHd0S3EC4UlJoxdeQDztydxcjiSgkGF6C8KjWWY8OVBzP01CcD1Sdw+HdEZbk6VW3uKqLbwdFVj5ehwPN8tCEIA//npNF5dfRTFpSbZpVEdw6BC9IdLeUUYsjAem09mQW1vhw+HtMPUfi050yzVWY72dnhvUFu8O7A17O1UWHs4HUMX70F2Pue6oprDoEIE4FDKVTz+yS6cyjDAp54aX40Nx5OdGsoui0gRhkc0xopRXeHu5IDDKXkY+MlOnLykl10W1REMKlTnrT+cjmc/3YMrBUaEBrhhfXR3dGrkJbssIkXp0cwH66O7I8THFZf0xXhqQTw2n8yUXRbVAQwqVGeZzQJzNp/G5K+PoKTMjIdb+ePbCZFo6M
n1TohuJcS3HtZN7I77mvmgqNSE8V8cxILt59jJlqoVgwrVSddKyhC96hA+2Xa90+yEXk2w6PlOXEGW6A60Lo5YFtUFIyIaQQhg1k+JePWbozCWsZMtVQ9FBBWj0Yj27dtDpVLhyJEjssuhWi5TX4ynF8Vj04lMONqrMGdIO7zxSCjs2GmWqFIc7O3wzsA2eOdGJ9tD6Ri2eC9yCtjJlqxPEUHl9ddfR/369WWXQXXA8TQ9Bs7biRPpBni5qrFqTDc8xU6zRHdlRERjLIvqAjcnBxy4eBUD5+3CmSzOZEvWJT2obNq0CT///DPmzJkjuxSq5TYdz8CQRbstM82un9gdXRqz0yzRvbi/uS/WTYxEI28XpF0twhPzd2P76cuyy6JaRGpQycrKwpgxY7By5Uq4uFSuA6PRaITBYCi3EVVECIF525Iw4ctDKC41o2dzX3w7MRJB3uw0S2QNTf3csH5id3QN9kKBsQwvLN+P5buS2cmWrEJaUBFCICoqCuPHj0fnzp0rfVxsbCy0Wq1l0+l01Vgl2TpjmQmvrj6K2ZtPAwCiIhtjycjOcOdMs0RW5emqxhejw/F054YwC+Dt709h2oYTKDWZZZdGNs7qQSUmJgYqlarCLTExEXPnzkV+fj6mTp1apfefOnUq9Hq9ZUtNTbX2KVAtkVNgxLDFe7H2cDrs7VR4d1AbvP14azjYS3/iSVQrqR3sMOvJMEztFwqVCvhiTwpeWL4f+qJS2aWRDVMJK9+by87ORk5OToVtQkJC8PTTT+P777+HSvXnSAuTyQR7e3sMGzYMK1asqNT3MxgM0Gq10Ov1cHd3v6faqfY4m5WPF1bsR2puEdycHDB/WEfc18xXdllEdcbPJzMx+esjuFZiQlO/elg6sgsft1I5lf38tnpQqayUlJRy/UsuXbqEvn37Ys2aNQgPD0fDhpUbicGgQn/3+9lsTPzyEPKLyxDk5YKlUZ3R1M9NdllEdc6JdD1eXHEAmYZieLmqsWh4J3ZgJwvFB5W/u3DhAoKDg3H48GG0b9++0scxqNBffbHnIqZ/dxIms0CXxp5YNLwzvFzVsssiqrOyDMV4ccUBHE/XQ21vh5lPtsUTHTklAFX+85sP66lWMJkFZnx/Em+uPwGTWeCJDg3wxYvhDClEkvm7O+Hrcd3wSOsAlJjMmLL6KD78+TTMZkX8PzLZAMXcUblbvKNCBcYyvPzVYfySeH3uhn/2aY7oB5qW6/9ERHKZzQKzfz6NBdvPAQAGhAVizpB2cHK0l1wZyVLZz28ubEI27VJeEV5Yvh+JmfnQONjhw6fbYUAYZzkmUho7OxXeeCQUwT6u+Nfa4/jhWAbS84rw6fDO8HXTyC6PFIyPfshmHUvLw8B5u5CYmQ+fehrEje3GkEKkcE931mHl6HBonR1xOCUPgzjtPt0BgwrZpJ9OZODpRfHIzjeihb8b1kdHokOQp+yyiKgSIpp4Y93ESDT2dkF6XhGenL8bO85kyy6LFIpBhWyKEAILd5zD+C/+nA5/zYQINPTk/AxEtiTEtx7WTeyO8GAv5P8x7f7KPRdll0UKxKBCNqPUZEbMt8cxc1MiAGBERCMsGdkZbpwOn8gmebqqsXJ0OJ7s2BAms8C09SfwzvenYOKIIPoLdqYlm6C/VoqJqw5iV1IO7FTAWwNaIap7sOyyiOgeqR3sMGdIGEJ8XTF782ks3ZWMlNxC/O/ZDnDV8COKeEeFbEBKzjU8sWAXdiXlwFVtj89GdmZIIapFVCoVoh9oik+e6wCNgx22JlzGkIXxyNAXyS6NFIBBhRTt4MVcDJq/C+eyCxGodcI34yPxYKi/7LKIqBoMCKuPr8Z2g089NU5lGDBo3i6cSNfLLoskkxpUGjdufNPKyjNnzpRZEinIhiPpGLp4L3ILS9C2gRYborujVX1O6kdUm3UM8sS6id3RzK8esgxGDFkYjy2nsmSXRRJJv6PyzjvvICMjw7K99NJLsksiyYQQ+L9fzuLluC
MoKTOjTyt/fD2uG/zcnWSXRkQ1QOflgjUTInFfMx8UlZowduUBfPb7edj4ROp0l6QHFTc3NwQEBFg2V1dX2SWRRMYyE15dfRQfbTkDABh7fwgWPt8JLmp2qiOqS7TOjlga1QVDuwZBCOC9jQmYtuEEykxm2aVRDZO61k/jxo1RXFyM0tJSBAUF4bnnnsMrr7wCB4fbfygZjUYYjUbLa4PBAJ1Ox7V+aoGrhSUY98VB7EvOhb2dCu8MbI1h4Y1kl0VEEgkhsPj384jdlAghgJ7NffHJcx04LUEtYBOrJ0+aNAlxcXHYtm0bxo0bhw8++ACvv/56hcfExsZCq9VaNp1OV0PVUnVKvlKIJxbsxr7kXLhpHLAsqgtDChFBpVJh7P1NsGBYJzg52mHHmWwMWRiP9DyOCKorrH5HJSYmBrNmzaqwTUJCAkJDQ2/av3TpUowbNw4FBQXQaG69SBXvqNQ++5JzMXblAeRdK0UDD2csjeqCFgFusssiIoU5lpaH0SsOIDvfCF83DZaM7Iywhh6yy6K7VNk7KlYPKtnZ2cjJyamwTUhICNRq9U37T548iTZt2iAxMREtWrSo1Per7ImSMq07nIY31hxHicmMdg21WDyyM/zc2GmWiG4tPa8Io/9YMd3J0Q7/e7YD+rYOkF0W3YXKfn5bvYeir68vfH197+rYI0eOwM7ODn5+flauipRGCIGPt57F/345CwDo1yYAHz3dHs5qe8mVEZGSNfBwxjfjI/CPVYex40w2xn9xEP/q1xIv3hcMlUoluzyqBtKGUsTHx2Pv3r144IEH4Obmhvj4eLzyyit4/vnn4enJVXBrM2OZCTHfHse6w+kAgHE9Q/BG31DY2fGPDBHdmZuTI5aM7Iy3vz+JL/ak4P0fE5CcU4h3Hm8NB3vpg1nJyqQFFY1Gg7i4OLz99tswGo0IDg7GK6+8gilTpsgqiWrA1cISjFt5EPsuXB/Z896gNhjaNUh2WURkYxzs7fDuwDZo7O2K939MwKq9KUi7WoR5HBFU60gdnmwN7KNiO5KvFOKF5fuRfKUQbhoHzH++I+5rdnePCYmIbvj5ZCZejjuColITQgPcsCSqCxp4OMsui+7AJoYnU92xLzkXg+fvQvKVQjTwcMa3EyMZUojIKvq0DsDqcRHwddMgMTMfg+btwrG0PNllkZUwqFC1W3c4Dc9/thd510rRTueB9dHd0dyfw4+JyHraNtRifXR3hAa4ITvfiKcXxWPzyUzZZZEVMKhQtbk+sucMXvn6KEpMZvRrE4C4Md3g63brOXKIiO7FjRFBPZv7orjUjPFfHOQaQbUAgwpVC2OZCVNWH8XHW68PPx53fwjmPdeRw4+JqFrdGBH0fDeuEVRbcKU3srq8ayUYu/LPNXs4soeIatLfRwR9sScFqblFXCPIRvGOClnVhSuFGDz/zzV7lo/qwpBCRDVOpVLhxfuur77+1zWCLnGNIJvDoEJWs/9C+ZE9ayZwZA8RydX3FiOCjqfpZZdFVcCgQlax4Ug6hi3ei6vXShHWUIt10ZFcWJCIFCGs4fXRhi383XD5jxFBW05lyS6LKolBhe6JEAJzfzmLl+OOoMRkRt/W/vh6bAQXFiQiRWng4YxvJkTgvmY+KCo1YezKA1i2K1l2WVQJ0oPKxo0bER4eDmdnZ3h6emLQoEGyS6JKKikz45/fHMOHW84AAMbcF4z5wzpxZA8RKZK7kyOWRl3vNycEMOP7U5jOEUGKJ3XUz7fffosxY8bggw8+wIMPPoiysjKcOHFCZklUSfprpRj3xQHsOX99ZM+Mx1vj+W6NZJdFRFQhR3s7fDC4DYJ9XPDBj4lYEX8RqVeLMHdoB7hqOBBWiaSt9VNWVobGjRtjxowZGD16dKWPMxqNMBqNltcGgwE6nY5r/dSglJxriFq+D+ezC+GqtscnwzrigRZ+sssiIqqSTcczMPnrIzCWmdEq0B1Lo7ogQMvH1jVF8Wv9HDp0COnp6bCzs0OHDh
0QGBiIfv363fGOSmxsLLRarWXT6XQ1VDEBwMGLVzF4/i6czy5EoNYJayZEMqQQkU3q1zYQcWO7waeeGqcyDBg0bxdOXTLILov+RlpQOX/+PADg7bffxptvvokffvgBnp6e6NWrF3Jzc2973NSpU6HX6y1bampqTZVc5208loGhi/cgp7AEbRq4Y310d7QM5F0sIrJdHYI8sW5idzTzq4dMQzGGLNyNbYmXZZdFf2H1oBITEwOVSlXhlpiYCLP5euelf//733jyySfRqVMnLFu2DCqVCt98881t31+j0cDd3b3cRtVLCIH525MQveoQSsrM6N3SD6vHRcDfnbdIicj26bxcsGZCJLo39UZhiQmjV+zHyvgLssuiP1i959Crr76KqKioCtuEhIQgIyMDANCqVSvLfo1Gg5CQEKSkpFi7LLpLpSYzpq0/gbj91+9cRUU2xrQBrWBvp5JcGRGR9WidHbEsqiveXH8cqw+kYdqGk7iYcw1TH23Jv3eSWT2o+Pr6wtf3zrORdurUCRqNBqdPn0aPHj0AAKWlpbhw4QIaNeLoESUwFJci+stD+P3sFdipgLcGtEJU92DZZRERVQu1gx1mPRmGRt6umL35ND7bmYyU3Gv4+Nn2cFFzRJAs0vqouLu7Y/z48Zg+fTp+/vlnnD59GhMmTAAADBkyRFZZ9Ie0q9fw1ILd+P3sFbio7bF4RGeGFCKq9VQqFaIfaIr/G9oBagc7/HwqC89+ugeX84tll1ZnSY2Is2fPhoODA4YPH46ioiKEh4fj119/haenp8yy6rxjaXkYveIAsvON8HPTYGlUF7RpoJVdFhFRjXm8XX3U1zphzOcHcCxNj8HzdmPZqC5o7s+lQWqatHlUrKWy47CpcjafzMTLcYdRXGpGaIAblkZ1QX0PZ9llERFJceFKIUYt34/kK4Vw0zhgwfOd0KOZj+yyagXFz6NCyiKEwGe/n8f4Lw6iuNSMns198c34CIYUIqrTGvu4Yu2ESHRt7IV8Yxmilu1D3D4O+KhJDCqEMpMZb204ifc2JkAIYFh4EJaM7Aw3J0fZpRERSefpqsbKF7tiUPv6KDMLxKw9jlk/JcJstukHEjaDQaWOKzCWYcznB7Byz0WoVMC/H22J9wa1gYM9/2kQEd2gcbDHf59pj0kPNQMALNh+Di/FHUZxqUlyZbUfP43qsEx9MZ5eGI9tp7OhcbDD/Oc6Ysz9IVCpOGcAEdHfqVQqTHm4OT4c0g6O9ipsPJaB5xbvQU6B8c4H011jUKmjTl36Y12LDAN86qnx9bgI9GsbKLssIiLFe7JTQ3z+QjjcnRxwKCUPg+fvxrnsAtll1VoMKnXQttOXMWThbmQaitHUrx7WTeyO9joP2WUREdmMiCbeWDuxO3RezkjJvYYn5u/GnvM5ssuqlRhU6piVey5i9PL9KCwxIbKJN76dEAmdl4vssoiIbM6N/9HrEOQBfVEphi/Zi7WH0mSXVeswqNQRJrPAez+cwrT1J2AWwJBODbF8VFdonTmyh4jobvnU0+CrMd3Qv20gSk0CU1YfxX+3nIGNT1GmKNKCyvbt22+7uvL+/ftllVUrFZWYMPHLg/hsZzIA4LW+LfCfp8KgdmBOJSK6V06O9pg7tAPG92wCAPjfL2fx6uqjMJZxRJA1SJuZtqSkBLm5ueX2TZs2Db/88gvOnTtX6ZEnnJm2YpfzizFmxQEcTdNDbW+H2UPCMLB9A9llERHVSl/tS8Gb60/AZBYID/bCouGd4OGill2WIil+Zlq1Wo2AgADL5u3tjQ0bNmDUqFEcHmslZ7LyMXjebhxN08PTxRFfjglnSCEiqkZDuwZhWVQX1NM4YG9yLp5YsBsXcwpll2XTFHPv/7vvvkNOTg5GjRpVYTuj0QiDwVBuo5vtSrqCJxfsRnpeEYJ9XLF2Ynd0aewluywiolrv/ua+WDMhAvW1TjifXYjB83fj4MWrssuyWYoJKkuWLEHfvn3RsGHDCtvFxsZCq9
VaNp1OV0MV2o7V+1Mxcuk+5BeXoWtjL6ydEIlgH1fZZRER1RmhAe5YH90dbRtokVtYgqGL92DjsQzZZdkkqweVmJiY23aSvbElJiaWOyYtLQ2bN2/G6NGj7/j+U6dOhV6vt2ypqanWPgWbZTYLzN6ciNe/PYYys8DA9vWx8sWu8HTl81Eioprm5+6Er8d1Q++W/igpMyN61SEs2H6OI4KqyOqdabOzs5GTU/GkNyEhIVCr//zwfPfddzF37lykp6fD0bFqw2XZmfa64lITXltzDN8fvQQAmPRgU7zycHP29yEiksxkFnhv4yks23UBADC0qw7vDGwDxzq+plplP78drP2NfX194evrW+n2QggsW7YMI0aMqHJIoetyC0sw9vMDOHDxKhzsVIh9oi2GdOYjMSIiJbC3U2H6Y63RyMsF7/xwCl/tS0Xa1SLMG9YR7lyl/o6kx7lff/0VycnJePHFF2WXYpPOZxdg8PxdOHDxKtycHPD5C10ZUoiIFCiqezA+Hd4Zzo72+P3sFQxZEI/0vCLZZSme9KCyZMkSREZGIjQ0VHYpNmefZejbNTT0dMbaCZGIbOojuywiIrqN3q388c34CPi5aXA6Kx+D5u3C8TS97LIUTdqEb9ZSV/uobDiSjte+OYYSkxntdB74bERn+LppZJdFRESVcCmvCC8s34/EzHw4O9rj/4Z2wMOt/GWXVaMUP+Eb3R0hBOb+chYvxx1BicmMvq39ETemG0MKEZENqe/hjG/GR+D+5r4oKjVh7MoDWLYrWXZZisSgYkNKysx4bc0xfLjlDABgzH3BWDCsE5zV9pIrIyKiqnJzcsSSkZ0xtGsQhABmfH8Kb393EiazTT/osDoGFRuhLypF1LJ9WHMwDXYq4N1BbfDv/q1gZ8fhx0REtsrR3g4fDG6DmH7X+2ku330B41YewLWSMsmVKQeDig1Izb2GJxfsxu5zOXBV22PJyC4Y3q2R7LKIiMgKVCoVxvdsgnnPdYTawQ5bEy7j6UXxuGwoll2aIjCoKNyR1DwMnr8LSZcL4O+uwerxEXgg1E92WUREZGX9wwLx1Zhu8HZV40S6AYPm7UJiJtezY1BRsJ9OZOLZT+NxpaAELQOvrxvRur5WdllERFRNOjXyxLqJ3RHi64pL+mI8tSAev53Jll2WVAwqCiSEwOLfzmPClwdRXGrGAy188c34CARqnWWXRkRE1SzI2wVrJ0QiPNgLBcYyjFq+H1/tS5FdljQMKgpTZjJj2oYTeP/HBAgBPN8tCItHdEY9jdVXOyAiIoXycFFj5ehwPNGhAUxmgalrj2PmpkSY6+CIIAYVBSkwluHFzw/giz0pUKmAN/u3xLsD28Chji9cRURUF6kd7PDh0+0wuXczAMDCHefw0leHUVxqklxZzZL6CXjmzBkMHDgQPj4+cHd3R48ePbBt2zaZJUmTqS/GkIXx2H46G06OdlgwrBNevC+Eqx8TEdVhKpUKk3s3x3+faQdHexU2Hs/Ac4v3IKfAKLu0GiM1qAwYMABlZWX49ddfcfDgQbRr1w4DBgxAZmamzLJq3KlL13t3J2QY4FNPjbixEXikTYDssoiISCEGd2iIlaPDoXV2xKGUPAyevxvnsgtkl1UjpK31c+XKFfj6+uK3337DfffdBwDIz8+Hu7s7tmzZgt69e9/yOKPRCKPxzyRpMBig0+lsdq2fbYmX8Y9Vh1BYYkJTv3pYFtUFOi8X2WUREZECncsuwKhl+5GSew1aZ0csGt4J3UK8ZZd1VxS/1o+3tzdatGiBzz//HIWFhSgrK8OiRYvg5+eHTp063fa42NhYaLVay6bT6WqwautaueciRq/Yj8ISEyKbeOPbCZEMKUREdFtNfOth3cRIdAjygL6oFMOX7MW6w2myy6pWUldPTktLw6BBg3Do0CHY2dnBz88PGzduRIcOHW57TG24o2I2C8RuSsDi368vQPVUp4b4YHBbqB3YaZaIiO6suNSEV1cfxcbjGQCAyb
2b4eWHmtlUv0Zpd1RiYmKgUqkq3BITEyGEQHR0NPz8/PD7779j3759GDRoEB577DFkZGTc9v01Gg3c3d3LbbakqMSECV8etISUf/ZpjtlPhTGkEBFRpTk52mPu0A4Y37MJAODjrWfx6uqjKCkzS67M+qx+RyU7Oxs5OTkVtgkJCcHvv/+OPn364OrVq+XCRrNmzTB69GjExMRU6vtVNpEpQXa+ES9+fgBHU/OgtrfD7CFhGNi+geyyiIjIhn21LwVvrj8Bk1mgW4gXFj3fGVoXR9ll3VFlP7+tPouYr68vfH1979ju2rVrAAA7u/J3Euzs7GA2175EeDYrH1HL9iM9rwgeLo74dHhndA32kl0WERHZuKFdg1DfwxnRXx7CnvO5GLxgF5ZHdUWQd+3o8yjteUNERAQ8PT0xcuRIHD16FGfOnMFrr72G5ORk9O/fX1ZZ1WJ30hU8sWA30vOK0PiPqZEZUoiIyFp6Nr+x1IoTzmcXYvD8XTiUclV2WVYhLaj4+Pjgp59+QkFBAR588EF07twZO3fuxIYNG9CuXTtZZVndNwdSMWLpPuQXl6FzI0+sndgdIb71ZJdFRES1zJ+L17ojp7AEQz/dgx+P377Pp62QOurHGpTaR0UIgY+2nMHcX5MAAAPCAjFnSDs4OdpLroyIiGqzQmMZXo47jK0JlwEAU/uFYuz9ypvpXPHzqNRmxjITJn99xBJSoh9ogv97tgNDChERVTtXjQMWDe+MqMjGAIDYTYn49/oTKDPZZv9PLslrZVcLSzBu5UHsu5ALBzsVPhjcFk93sd1J6YiIyPbY26nw9uOt0cjbBe/8cAqr9qYg/WoRPnmuA9yclD8i6K94R8WKLlwpxBMLdmPfhVy4aRywfFRXhhQiIpJmVPdgfDq8M5wd7bHjTDaGLIzHpbwi2WVVCYOKlRy4kIvB83ch+UohGng449uJkejRzEd2WUREVMc93Mofq8dFwNdNg8TMfAyevwsn0vWyy6o0BhUr+P7oJTz32V5cvVaKsIZarIuORHN/N9llERERAQDaNtRi3cRINPevhyyDEU8viscvCVmyy6oUBpV7IITAvG1JeOmrwygpM+PhVv6IG9sNfm5OsksjIiIqp6GnC9ZMiMR9zXxwrcSEMZ8fwIrdF2SXdUcMKnep1GRGzLfHMXvzaQDA6B7BWPh8J7io2T+ZiIiUyd3JEUujuuDZLjqYBTD9u5N45/tTMJmVO1MJP1Xvgr6oFBO/PIhdSTmwUwHTH2uNkX8MAyMiIlIyR3s7xD7RFkHeLvjPT6exdFcyUq9ew/+eba/I/9mWekfl0KFDePjhh+Hh4QFvb2+MHTsWBQUFMku6o7Sr1zBk4W7sSsqBi9oei0d0ZkghIiKbolKpMLFXU8wd2gFqBztsOZWFZz/dg8v5xbJLu4m0oHLp0iX07t0bTZs2xd69e/HTTz/h5MmTiIqKklXSHR1Ly8Pg+btxJqsAfm4arB4XgYda+ssui4iI6K481q4+Vr0YDk8XRxxL02PwvN04k5Uvu6xypE2h/+mnn2LatGnIyMiwrKB8/PhxhIWF4ezZs2jatGml3qemptD/+WQmJsUdRnGpGaEBblga1QX1PZyr7fsRERHVlAtXCvHC8v04f6UQbhoHLHi+U7VPsaH4KfSNRiPUarUlpACAs/P1D/6dO3dWeJzBYCi3VSchBJbsTMa4Lw6iuNSM+/9YoZIhhYiIaovGPq74dkIkujb2Qr6xDFHL9mH1/lTZZQGQGFQefPBBZGZmYvbs2SgpKcHVq1cRExMDAMjIuP1qj7GxsdBqtZZNp6u+mV9NZoG3vzuJd384BSGAoV2DsGRkZ5ubfpiIiOhOPF3VWPliVwxsXx9lZoHXvz2G2ZsTYZY8IsjqQSUmJgYqlarCLTExEa1bt8aKFSvw4YcfwsXFBQEBAQgODoa/v3+5uyx/N3XqVOj1esuWmlo9ia/QWIaxnx/AiviL179vv1B8MLgNHO05opuIiG
onjYM9Pn6mPSY9eL37xbxt5/Dy10dQXGqSVpPV+6hkZ2cjJyenwjYhISFQq9WW11lZWXB1dYVKpYK7uzvi4uIwZMiQSn2/6uqj8uKK/diacBkaBzv895n2eLRtoNXem4iISOm+OZCKqWuPo8ws8ESHBvjomfZWff/Kfn5bfcC0r68vfH19q3SMv//1kTNLly6Fk5MTHn74YWuXVWVTHm6BM1kF+PjZ9ugY5Cm7HCIioho1pLMODTycEbP2OP7xYOUGuFQHaaN+AOCTTz5BZGQk6tWrhy1btuC1117DzJkzMWnSpEq/R3WO+ik1mfmoh4iI6rTq+iyUdkelKvbt24fp06ejoKAAoaGhWLRoEYYPHy6zpHIYUoiIqK6T/VkoNah8/vnnMr89ERERKRxvGRAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWJVW1B5//33ERkZCRcXF3h4eNyyTUpKCvr37w8XFxf4+fnhtddeQ1lZWXWVRERERDam2qbQLykpwZAhQxAREYElS5bc9HWTyYT+/fsjICAAu3fvRkZGBkaMGAFHR0d88MEH1VUWERER2ZBqXz15+fLlmDx5MvLy8srt37RpEwYMGIBLly7B398fALBw4UK88cYbyM7OhlqtvuX7GY1GGI1Gy2u9Xo+goCCkpqZaffVkIiIiqh4GgwE6nQ55eXnQarW3bSdtUcL4+Hi0bdvWElIAoG/fvpgwYQJOnjyJDh063PK42NhYzJgx46b9Op2u2molIiKi6pGfn6/MoJKZmVkupACwvM7MzLztcVOnTsWUKVMsr81mM3Jzc+Ht7Q2VSmXVGm+kvdp6t4bnZ/tq+zny/GxfbT9Hnt/dE0IgPz8f9evXr7BdlYJKTEwMZs2aVWGbhIQEhIaGVuVtq0Sj0UCj0ZTbd7vOutbi7u5eK/8B3sDzs321/Rx5fravtp8jz+/uVHQn5YYqBZVXX30VUVFRFbYJCQmp1HsFBARg37595fZlZWVZvkZERERUpaDi6+sLX19fq3zjiIgIvP/++7h8+TL8/PwAAFu2bIG7uztatWplle9BREREtq3a+qikpKQgNzcXKSkpMJlMOHLkCACgadOmqFevHvr06YNWrVph+PDh+M9//oPMzEy8+eabiI6OvunRjiwajQbTp09XTD3WxvOzfbX9HHl+tq+2nyPPr/pV2/DkqKgorFix4qb927ZtQ69evQAAFy9exIQJE7B9+3a4urpi5MiRmDlzJhwcpPXxJSIiIgWp9nlUiIiIiO4W1/ohIiIixWJQISIiIsViUCEiIiLFYlAhIiIixarTQeX9999HZGQkXFxcKj27rRACb731FgIDA+Hs7IzevXvj7Nmz5drk5uZi2LBhcHd3h4eHB0aPHo2CgoJqOIOKVbWOCxcuQKVS3XL75ptvLO1u9fW4uLiaOKWb3M3PulevXjfVP378+HJtUlJS0L9/f7i4uMDPzw+vvfYaysrKqvNUbqmq55ebm4uXXnoJLVq0gLOzM4KCgjBp0iTo9fpy7WRew3nz5qFx48ZwcnJCeHj4TRM//t0333yD0NBQODk5oW3btvjxxx/Lfb0yv5M1qSrnt3jxYtx3333w9PSEp6cnevfufVP7qKiom67VI488Ut2ncVtVOb/ly5ffVLuTk1O5Nkq7fkDVzvFWf09UKhX69+9vaaOka/jbb7/hscceQ/369aFSqbB+/fo7HrN9+3Z07NgRGo0GTZs2xfLly29qU9Xf6yoRddhbb70lPvroIzFlyhSh1WordczMmTOFVqsV69evF0ePHhWPP/64CA4OFkVFRZY2jzzyiGjXrp3Ys2eP+P3330XTpk3F0KFDq+ksbq+qdZSVlYmMjIxy24wZM0S9evVEfn6+pR0AsWzZsnLt/nr+NeluftY9e/YUY8aMKVe/Xq+3fL2srEy0adNG9O7dWxw+fFj8+OOPwsfHR0
ydOrW6T+cmVT2/48ePiyeeeEJ89913IikpSfzyyy+iWbNm4sknnyzXTtY1jIuLE2q1WixdulScPHlSjBkzRnh4eIisrKxbtt+1a5ewt7cX//nPf8SpU6fEm2++KRwdHcXx48ctbSrzO1lTqnp+zz33nJg3b544fPiwSEhIEFFRUUKr1Yq0tDRLm5EjR4pHHnmk3LXKzc2tqVMqp6rnt2zZMuHu7l6u9szMzHJtlHT9hKj6Oebk5JQ7vxMnTgh7e3uxbNkySxslXcMff/xR/Pvf/xZr164VAMS6desqbH/+/Hnh4uIipkyZIk6dOiXmzp0r7O3txU8//WRpU9WfWVXV6aByw7JlyyoVVMxmswgICBCzZ8+27MvLyxMajUZ89dVXQgghTp06JQCI/fv3W9ps2rRJqFQqkZ6ebvXab8dadbRv31688MIL5fZV5h93Tbjbc+zZs6d4+eWXb/v1H3/8UdjZ2ZX7g7pgwQLh7u4ujEajVWqvDGtdw9WrVwu1Wi1KS0st+2Rdw65du4ro6GjLa5PJJOrXry9iY2Nv2f7pp58W/fv3L7cvPDxcjBs3TghRud/JmlTV8/u7srIy4ebmJlasWGHZN3LkSDFw4EBrl3pXqnp+d/rbqrTrJ8S9X8P//ve/ws3NTRQUFFj2Keka/lVl/g68/vrronXr1uX2PfPMM6Jv376W1/f6M7uTOv3op6qSk5ORmZmJ3r17W/ZptVqEh4cjPj4eABAfHw8PDw907tzZ0qZ3796ws7PD3r17a6xWa9Rx8OBBHDlyBKNHj77pa9HR0fDx8UHXrl2xdOlSCAnT8dzLOX755Zfw8fFBmzZtMHXqVFy7dq3c+7Zt27bc6t59+/aFwWDAyZMnrX8it2Gtf0t6vR7u7u43TaRY09ewpKQEBw8eLPf7Y2dnh969e1t+f/4uPj6+XHvg+rW40b4yv5M15W7O7++uXbuG0tJSeHl5ldu/fft2+Pn5oUWLFpgwYQJycnKsWntl3O35FRQUoFGjRtDpdBg4cGC53yElXT/AOtdwyZIlePbZZ+Hq6lpuvxKu4d240++gNX5md8IpYKsgMzMTAMp9gN14feNrmZmZlrWLbnBwcICXl5elTU2wRh1LlixBy5YtERkZWW7/O++8gwcffBAuLi74+eefMXHiRBQUFGDSpElWq78y7vYcn3vuOTRq1Aj169fHsWPH8MYbb+D06dNYu3at5X1vdY1vfK2mWOMaXrlyBe+++y7Gjh1bbr+Ma3jlyhWYTKZb/mwTExNvecztrsVff99u7Ltdm5pyN+f3d2+88Qbq169f7o/+I488gieeeALBwcE4d+4c/vWvf6Ffv36Ij4+Hvb29Vc+hIndzfi1atMDSpUsRFhYGvV6POXPmIDIyEidPnkTDhg0Vdf2Ae7+G+/btw4kTJ7BkyZJy+5VyDe/G7X4HDQYDioqKcPXq1Xv+d38ntS6oxMTEYNasWRW2SUhIQGhoaA1VZF2VPb97VVRUhFWrVmHatGk3fe2v+zp06IDCwkLMnj3bah9y1X2Of/3Qbtu2LQIDA/HQQw/h3LlzaNKkyV2/b2XV1DU0GAzo378/WrVqhbfffrvc16r7GlLVzZw5E3Fxcdi+fXu5DqfPPvus5b/btm2LsLAwNGnSBNu3b8dDDz0ko9RKi4iIQEREhOV1ZGQkWrZsiUWLFuHdd9+VWFn1WLJkCdq2bYuuXbuW22/L11AJal1QefXVVxEVFVVhm5CQkLt674CAAABAVlYWAgMDLfuzsrLQvn17S5vLly+XO66srAy5ubmW4+9FZc/vXutYs2YNrl27hhEjRtyxbXh4ON59910YjUarLFxVU+d4Q3h4OAAgKSkJTZo0QUBAwE091rOysgDAZq5hfn4+HnnkEbi5uWHdunVwdHSssL21r+Gt+Pj4wN7e3vKzvCErK+u25xMQEFBh+8r8TtaUuzm/G+bMmYOZM2di69atCAsLq7BtSEgIfHx8kJSUVKMfcvdyfj
c4OjqiQ4cOSEpKAqCs6wfc2zkWFhYiLi4O77zzzh2/j6xreDdu9zvo7u4OZ2dn2Nvb3/O/izuySk8XG1fVzrRz5syx7NPr9bfsTHvgwAFLm82bN0vrTHu3dfTs2fOmkSK389577wlPT8+7rvVuWetnvXPnTgFAHD16VAjxZ2fav/ZYX7RokXB3dxfFxcXWO4E7uNvz0+v1olu3bqJnz56isLCwUt+rpq5h165dxT/+8Q/La5PJJBo0aFBhZ9oBAwaU2xcREXFTZ9qKfidrUlXPTwghZs2aJdzd3UV8fHylvkdqaqpQqVRiw4YN91xvVd3N+f1VWVmZaNGihXjllVeEEMq7fkLc/TkuW7ZMaDQaceXKlTt+D5nX8K9Qyc60bdq0Kbdv6NChN3WmvZd/F3es0yrvYqMuXrwoDh8+bBmCe/jwYXH48OFyQ3FbtGgh1q5da3k9c+ZM4eHhITZs2CCOHTsmBg4ceMvhyR06dBB79+4VO3fuFM2aNZM2PLmiOtLS0kSLFi3E3r17yx139uxZoVKpxKZNm256z++++04sXrxYHD9+XJw9e1bMnz9fuLi4iLfeeqvaz+dWqnqOSUlJ4p133hEHDhwQycnJYsOGDSIkJETcf//9lmNuDE/u06ePOHLkiPjpp5+Er6+vtOHJVTk/vV4vwsPDRdu2bUVSUlK54ZBlZWVCCLnXMC4uTmg0GrF8+XJx6tQpMXbsWOHh4WEZYTV8+HARExNjab9r1y7h4OAg5syZIxISEsT06dNvOTz5Tr+TNaWq5zdz5kyhVqvFmjVryl2rG3+D8vPzxT//+U8RHx8vkpOTxdatW0XHjh1Fs2bNajQ03+35zZgxQ2zevFmcO3dOHDx4UDz77LPCyclJnDx50tJGSddPiKqf4w09evQQzzzzzE37lXYN8/PzLZ91AMRHH30kDh8+LC5evCiEECImJkYMHz7c0v7G8OTXXntNJCQkiHnz5t1yeHJFP7N7VaeDysiRIwWAm7Zt27ZZ2uCP+SZuMJvNYtq0acLf319oNBrx0EMPidOnT5d735ycHDF06FBRr1494e7uLkaNGlUu/NSUO9WRnJx80/kKIcTUqVOFTqcTJpPppvfctGmTaN++vahXr55wdXUV7dq1EwsXLrxl25pQ1XNMSUkR999/v/Dy8hIajUY0bdpUvPbaa+XmURFCiAsXLoh+/foJZ2dn4ePjI1599dVyw3trSlXPb9u2bbf8Nw1AJCcnCyHkX8O5c+eKoKAgoVarRdeuXcWePXssX+vZs6cYOXJkufarV68WzZs3F2q1WrRu3Vps3Lix3Ncr8ztZk6pyfo0aNbrltZo+fboQQohr166JPn36CF9fX+Ho6CgaNWokxowZY7UPgLtRlfObPHmypa2/v7949NFHxaFDh8q9n9KunxBV/zeamJgoAIiff/75pvdS2jW83d+IG+c0cuRI0bNnz5uOad++vVCr1SIkJKTcZ+INFf3M7pVKCAnjSomIiIgqgfOoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFi/T8qIITBChX6pQAAAABJRU5ErkJggg==",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# log(1/(1 + exp(-s*(x + eps))) - 1/(1 + exp(-s*(x - eps)))), with s = inv_scale and eps = 0.1\n",
"# plot \n",
"inv_scale = 10\n",
"f = lambda x : np.log(1/(1 + np.exp(- inv_scale * (x + 0.1))) - 1/(1 + np.exp(- inv_scale* (x - 0.1))))\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"x = np.linspace(-1, 1, 100)\n",
"y = f(x)\n",
"plt.plot(x, y)\n",
"plt.yticks(np.arange(-10, 10, 1))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100,\n",
" 0.9100])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"num_classes = 10\n",
"beta = 0.1\n",
"mat = torch.ones(num_classes, num_classes) * beta / num_classes\n",
"# diagonal is 1 - (K - 1)/K * beta\n",
"mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100,\n",
" 0.0100],\n",
" [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n",
" 0.9100]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mat"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 1, 32, 32, 16])\n"
]
}
],
"source": [
"from typing import Dict, Tuple\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"from torchvision.utils import save_image, make_grid\n",
"\n",
"\n",
"blk = lambda ic, oc: nn.Sequential(\n",
" nn.Conv2d(ic, oc, 7, padding=3),\n",
" nn.GroupNorm(oc // 8, oc, affine = False, eps = 1e-4),\n",
" nn.LeakyReLU(),\n",
")\n",
"\n",
"\n",
"class DummyX0Model(nn.Module):\n",
" \"\"\"\n",
" This should be unet-like, but let's not think about the model too much :P\n",
" Basically, any universal R^n -> R^n model should work.\n",
" \"\"\"\n",
"\n",
" def __init__(self, n_channel: int, N) -> None:\n",
" super(DummyX0Model, self).__init__()\n",
" self.start = blk(n_channel, 16)\n",
" self.pe = nn.Parameter(torch.randn(1, 16, 32, 32))\n",
" self.conv = nn.Sequential(\n",
" blk(16, 128),\n",
" blk(128, 256),\n",
" blk(256, 512),\n",
" blk(512, 256),\n",
" blk(256, 128),\n",
" blk(128, 64),\n",
" nn.Conv2d(64, 2 * n_channel, 3, padding=1),\n",
" )\n",
" self.N = N\n",
"\n",
" def forward(self, x, t) -> torch.Tensor:\n",
" # Let's think about using t later. In the paper, they used Transformer-like positional embeddings.\n",
" st = x.float() / self.N - 0.5\n",
" x = self.start(st) + self.pe\n",
" y = self.conv(x)\n",
" loc, log_scale = y.chunk(2, dim=1)\n",
"\n",
" return torch.tanh(loc + st), log_scale\n",
"\n",
"blk = lambda ic, oc: nn.Sequential(\n",
" nn.Conv2d(ic, oc, 5, padding=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
" nn.Conv2d(oc, oc, 5, padding=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
" nn.Conv2d(oc, oc, 5, padding=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
")\n",
"\n",
"blku = lambda ic, oc: nn.Sequential(\n",
" nn.Conv2d(ic, oc, 5, padding=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
" nn.Conv2d(oc, oc, 5, padding=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
" nn.Conv2d(oc, oc, 5, padding=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
" nn.ConvTranspose2d(oc, oc, 2, stride=2),\n",
" nn.GroupNorm(oc // 8, oc),\n",
" nn.LeakyReLU(),\n",
" \n",
")\n",
"\n",
"class DummyX0Model(nn.Module):\n",
" \n",
"\n",
" def __init__(self, n_channel: int, N : int = 16) -> None:\n",
" super(DummyX0Model, self).__init__()\n",
" self.down1 = blk(n_channel, 16)\n",
" self.down2 = blk(16, 32)\n",
" self.down3 = blk(32, 64)\n",
" self.down4 = blk(64, 512)\n",
" self.down5 = blk(512, 512)\n",
" self.up1 = blku(512, 512)\n",
" self.up2 = blku(512 + 512, 64) # Corrected to account for concatenated feature maps\n",
" self.up3 = blku(64 + 64, 32) # Corrected to account for concatenated feature maps\n",
" self.up4 = blku(32 + 32, 16) # Corrected to account for concatenated feature maps\n",
" self.convlast = blk(32, 16)\n",
" self.final = nn.Conv2d(16, N * n_channel, 1, bias = False)\n",
"\n",
" # initialize final with zero\n",
" #self.final.weight.data.zero_()\n",
" #self.final.bias.data.zero_()\n",
"\n",
" self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n",
" self.tr2 = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n",
" self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8)\n",
"\n",
" self.temb_1 = nn.Linear(32, 16)\n",
" self.temb_2 = nn.Linear(32, 32)\n",
" self.temb_3 = nn.Linear(32, 64)\n",
" self.temb_4 = nn.Linear(32, 512)\n",
" self.N = N\n",
"\n",
" def forward(self, x, t) -> torch.Tensor:\n",
" x = (2 * x.float() / self.N) - 1.0\n",
" t = t.float().reshape(-1, 1) / 500\n",
" t_as_sin = [torch.sin(t * 3.1415 * 2 ** i) for i in range(16)]\n",
" t_as_cos = [torch.cos(t * 3.1415 * 2 ** i) for i in range(16)]\n",
" # concat and send it to t_emb\n",
" t_emb_1 = self.temb_1(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
" t_emb_2 = self.temb_2(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
" t_emb_3 = self.temb_3(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
" t_emb_4 = self.temb_4(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n",
" \n",
" x1 = self.down1(x) + t_emb_1\n",
" x2 = self.down2(nn.functional.avg_pool2d(x1, 2)) + t_emb_2\n",
" x3 = self.down3(nn.functional.avg_pool2d(x2, 2)) + t_emb_3\n",
" x4 = self.down4(nn.functional.avg_pool2d(x3, 2)) + t_emb_4\n",
" x5 = self.down5(nn.functional.avg_pool2d(x4, 2))\n",
"\n",
" x5 = self.tr1(x5.reshape(x5.shape[0], x5.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(x5.shape)\n",
"\n",
" y = self.up1(x5)\n",
"\n",
" y = self.tr2(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(y.shape)\n",
" \n",
" y = self.up2(torch.cat([x4, y], dim=1))\n",
"\n",
" y = self.tr3(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(y.shape)\n",
"\n",
" y = self.up3(torch.cat([x3, y], dim=1))\n",
"\n",
" y = self.up4(torch.cat([x2, y], dim=1))\n",
"\n",
"\n",
" y = self.convlast(torch.cat([x1, y], dim=1))\n",
" y = self.final(y)\n",
" # reshape to B, C, H, W, N\n",
" y = y.reshape(y.shape[0], -1, self.N, *x.shape[2:]).transpose(2, -1).contiguous()\n",
" \n",
" return y\n",
"\n",
"\n",
"# check that it accepts a (2, 1, 32, 32) input\n",
"\n",
"model = DummyX0Model(1, N = 16)\n",
"print(model(torch.randn(2, 1, 32, 32), torch.tensor([1, 2])).shape)\n",
"\n",
"\n",
"\n",
"def get_logits_from_logistic_pars(loc, log_scale, num_classes = 10):\n",
" loc = loc.unsqueeze(-1)\n",
" log_scale = log_scale.unsqueeze(-1)\n",
" inv_scale = (-log_scale + 2.0).exp()\n",
"\n",
" bin_width = 2.0 / (num_classes - 1)\n",
" bin_centers = torch.linspace(-1.0, 1.0, num_classes).to(loc.device)\n",
" bin_centers = bin_centers.reshape([1] * (len(loc.shape) - 1) + [num_classes])\n",
" bin_centers = bin_centers - loc\n",
" log_cdf_min = -torch.log1p((-inv_scale * (bin_centers - 0.5 * bin_width)).exp())\n",
" log_cdf_plus = -torch.log1p((-inv_scale * (bin_centers + 0.5 * bin_width)).exp())\n",
" logits = log_minus_exp(log_cdf_plus, log_cdf_min)\n",
" return logits\n",
"\n",
"\n",
"def log_minus_exp(a, b, epsilon=1e-6):\n",
" return a + torch.log1p(-torch.exp(b - a) + epsilon)\n",
"\n",
"\n",
"\n",
"class D3PM(nn.Module):\n",
" def __init__(\n",
" self,\n",
" x0_model: nn.Module,\n",
" n_T: int,\n",
" num_classes: int = 10,\n",
" forward_type = 'uniform',\n",
" hybrid_loss_coeff = 0.001\n",
" ) -> None:\n",
" super(D3PM, self).__init__()\n",
" self.x0_model = x0_model\n",
"\n",
" self.n_T = n_T\n",
" self.hybrid_loss_coeff = hybrid_loss_coeff\n",
" self.beta_t = [1 / (self.n_T - t + 1) for t in range(1, self.n_T + 1)]\n",
" self.eps = 1e-8\n",
" self.num_classses = num_classes\n",
" q_onestep_mats = []\n",
" q_mats = [] # these are cumulative\n",
"\n",
" for beta in self.beta_t:\n",
"\n",
" if forward_type == 'uniform':\n",
" mat = torch.ones(num_classes, num_classes) * beta / num_classes\n",
" mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)\n",
" q_onestep_mats.append(mat)\n",
" else:\n",
" raise NotImplementedError\n",
" q_one_step_mats = torch.stack(q_onestep_mats, dim=0)\n",
"\n",
" q_one_step_transposed = q_one_step_mats.transpose(1, 2) # this will be used for q_posterior_logits\n",
"\n",
" q_mat_t = q_onestep_mats[0]\n",
" q_mats = [q_mat_t]\n",
" for idx in range(1, self.n_T):\n",
" q_mat_t = q_mat_t @ q_onestep_mats[idx]\n",
" q_mats.append(q_mat_t)\n",
" q_mats = torch.stack(q_mats, dim=0)\n",
" self.logit_type = 'logit'\n",
"\n",
" # register\n",
" self.register_buffer(\"q_one_step_transposed\", q_one_step_transposed)\n",
" self.register_buffer(\"q_mats\", q_mats)\n",
"\n",
" assert self.q_mats.shape == (self.n_T, num_classes, num_classes), self.q_mats.shape\n",
" \n",
" def _at(self, a, t, x):\n",
" # t is 1-d, x is integer value of 0 to num_classes - 1\n",
" bs = t.shape[0]\n",
" t = t.reshape((bs, *[1] * (x.dim() - 1)))\n",
" #out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m]\n",
" return a[t - 1, x, :]\n",
"\n",
"\n",
" def q_posterior_logits(self, x_0, x_t, t):\n",
" # if t == 1, this means we return the L_0 loss, so we directly use the x_0 logits.\n",
" # otherwise, we return the L_{t-1} loss.\n",
" # Also, we never have t == 0.\n",
"\n",
" # if x_0 is integer, we convert it to one-hot.\n",
" if x_0.dtype == torch.int64 or x_0.dtype == torch.int32:\n",
" x_0_logits = torch.log(torch.nn.functional.one_hot(x_0, self.num_classses) + self.eps)\n",
" else:\n",
" x_0_logits = x_0.clone()\n",
"\n",
" assert x_0_logits.shape == x_t.shape + (self.num_classses,), print(f\"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}\")\n",
"\n",
" # Here, we calculate equation (3) of the paper. Note that the x_0 Q_t x_t^T term is a normalizing constant, so we don't deal with that.\n",
" # fact1 is \"guess of x_{t-1}\" from x_t\n",
" # fact2 is \"guess of x_{t-1}\" from x_0\n",
"\n",
" fact1 = self._at(self.q_one_step_transposed, t, x_t)\n",
" #fact2 = self._at_onehot(self.q_mats, t-1, )\n",
" # x, a[t-1]\n",
"\n",
"\n",
" softmaxed = torch.softmax(x_0_logits, dim=-1) # bs, ..., num_classes\n",
" qmats2 = self.q_mats[t-2] # bs, num_classes, num_classes\n",
"\n",
" fact2 = torch.einsum('b...c,bcd->b...d', softmaxed, qmats2)\n",
" \n",
" #print(f\"Fact1Fact2\", fact1.shape, fact2.shape)\n",
" out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)\n",
" #print(f\"out: {out.shape}\")\n",
" t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim())))\n",
" #print(t_broadcast.shape)\n",
" bc = torch.where(t_broadcast == 1, x_0_logits, out)\n",
" #print(f\"bc: {bc.shape}\")\n",
" return bc\n",
"\n",
" def vb(self, dist1, dist2):\n",
" out = (torch.softmax(dist1 + self.eps, dim = -1)*(torch.log_softmax(dist1 + self.eps, dim = -1) - torch.log_softmax(dist2 + self.eps, dim = -1)))\n",
" return out.sum(dim=-1).mean()\n",
"\n",
" \n",
" def q_sample(self, x_0, t, noise):\n",
" # forward process, x_0 is the clean input.\n",
" \n",
" logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)\n",
" noise = torch.clip(noise, self.eps, 1.0)\n",
" gumbel_noise = -torch.log(-torch.log(noise))\n",
" return torch.argmax(logits + gumbel_noise, dim=-1)\n",
"\n",
" def model_predict(self, x_0, t):\n",
" if self.logit_type == 'logit':\n",
" predicted_x0_logits = self.x0_model(x_0, t)\n",
" else:\n",
" loc, log_scale = self.x0_model(x_0, t)\n",
" predicted_x0_logits = get_logits_from_logistic_pars(loc, log_scale, self.num_classses)\n",
" return predicted_x0_logits\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" \"\"\"\n",
" Makes forward diffusion x_t from x_0, and tries to guess x_0 value from x_t using x0_model.\n",
" x is an integer tensor of shape (bs, ...), with values from 0 to num_classes - 1 (not one-hot)\n",
" \"\"\"\n",
" t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)\n",
" x_t = self.q_sample(x, t, torch.rand((*x.shape, self.num_classses), device=x.device))\n",
" # x_t is same shape as x\n",
" assert x_t.shape == x.shape, print(f\"x_t.shape: {x_t.shape}, x.shape: {x.shape}\")\n",
" # we use hybrid loss.\n",
" \n",
" predicted_x0_logits = self.model_predict(x, t)\n",
"\n",
"\n",
" # based on this, we first do vb loss.\n",
" true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)\n",
" #print(f\"predicted_x0_logits: {predicted_x0_logits.shape}, true_q_posterior_logits: {true_q_posterior_logits.shape}\")\n",
" pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)\n",
"\n",
" vb_loss = self.vb(true_q_posterior_logits,pred_q_posterior_logits)\n",
"\n",
"\n",
"\n",
" predicted_x0_logits = predicted_x0_logits.flatten(start_dim = 0, end_dim = -2)\n",
" x = x.flatten(start_dim = 0, end_dim = -1)\n",
" #print(f\"predicted_x0_logits: {predicted_x0_logits.shape}, x: {x.shape}\")\n",
"\n",
"\n",
" ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x)\n",
"\n",
" return vb_loss + ce_loss*self.hybrid_loss_coeff, {\"vb_loss\": vb_loss.detach().item(), \"ce_loss\": ce_loss.detach().item()}\n",
"\n",
" def p_sample(self, x, t, noise):\n",
" \n",
" predicted_x0_logits = self.model_predict(x, t)\n",
" pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)\n",
"\n",
" noise = torch.clip(noise, self.eps, 1.0)\n",
"\n",
" not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim())))\n",
" gumbel_noise = -torch.log(-torch.log(noise))\n",
" sample = torch.argmax(pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1)\n",
" return sample\n",
"\n",
" def sample(self, x = None):\n",
" for t in reversed(range(1, self.n_T)):\n",
" t = torch.tensor([t]*x.shape[0], device=x.device)\n",
" x = self.p_sample(x, t, torch.rand((*x.shape, self.num_classses), device=x.device))\n",
" \n",
" return x\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#d3pm = D3PM(DummyX0Model(1, 20), 1000, num_classes = 20).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#d3pm.sample(torch.randint(0, 2, (2, 1, 32, 32)).cuda())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# d3pm.train()\n",
"# d3pm.forward(torch.randint(0, 2, (128, 1, 28, 28)).cuda())\n",
"# dataset = MNIST(\n",
"# \"./data\",\n",
"# train=True,\n",
"# download=True,\n",
"# transform=transforms.ToTensor()\n",
"# )\n",
"# dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=20)\n",
"# d3pm(x = next(iter(dataloader))[0].cuda().long())\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Param Count: 63336704\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 30, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" warnings.warn(_create_warning_msg(\n"
]
}
],
"source": [
"N = 16\n",
"d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes = N, hybrid_loss_coeff=0.0).cuda()\n",
"print(f\"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}\")\n",
"dataset = MNIST(\n",
" \"./data\",\n",
" train=True,\n",
" download=True,\n",
" transform= transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Pad(2),\n",
" ])\n",
" )\n",
"dataloader = DataLoader([dataset[0]] * 50000, batch_size=32, shuffle=True, num_workers=32)\n",
"#dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32)\n",
"\n",
"optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=4e-4, betas=(0.95, 0.99))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1563 [00:00, ?it/s]/root/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 30, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" warnings.warn(_create_warning_msg(\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 0.0025, norm: 0.0052, param_norm: 923.3715, vb_loss: 0.0025, ce_loss: 2.5743: 0%| | 0/1563 [00:02, ?it/s]"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl50lEQVR4nO3dfXAV1fkH8G9CyA0KuSk4JKYQklZHfC00CEadKpoWqaOiTKsOrcE6ddTEGpmpilbbsUOD44xiO4jTjk3sVIqlI7G+oGMDYrHhLSW+lBpx5G3UG2qd5MYXAianf/jj/s4ecp/ds7t3703y/cxkJje7d/fcs7uXw3nOeU6eUkqBiIiIKCL52S4AERERjS5sfBAREVGk2PggIiKiSLHxQURERJFi44OIiIgixcYHERERRYqNDyIiIooUGx9EREQUKTY+iIiIKFJsfBAREVGkMtb4WLlyJSorK1FUVIQ5c+Zg27ZtmToVERERDSN5mVjb5amnnsJ1112Hxx57DHPmzMGKFSuwdu1adHV1YfLkyeJ7BwcH8cEHH2DChAnIy8sLu2hERESUAUop9PX1oby8HPn5Ln0bKgNmz56t6uvrU68HBgZUeXm5ampqcn3vgQMHFAD+8Ic//OEPf/gzDH8OHDjg+m996GGXw4cPo6OjA7W1tam/5efno7a2Fu3t7cfs39/fj2QymfpRXGSXiIho2JowYYLrPqE3Pj766CMMDAygtLTU8ffS0lIkEolj9m9qakI8Hk/9VFRUhF0kIiIiioiXIRNZn+2ydOlS9Pb2pn4OHDiQ7SIRERFRBhWEfcATTjgBY8aMQXd3t+Pv3d3dKCsrO2b/WCyGWCwWdjGIiIgoR4Xe81FYWIjq6mq0tbWl/jY4OIi2tjbU1NSEfToiIiIaZkLv+QCAJUuWoK6uDrNmzcLs2bOxYsUKfPrpp7j++usDH3vdunUhlBCorKx0vN67d28ox73yyitTv+/cuTMj5xhOenp6HK9LSkpSv+t1ZQpynfVra1PnjY2NjtcrVqzwXQb6f9J1BsJ7pv165ZVXHK8vvPDCSM+fqe+iKJjPzL59+9Lum+3rbCOsa6J//+nffcOd2zPtRUYaH1dffTX+85//4L777kMikcCMGTPw4osvHjMIlYiIiEafjDQ+AKChoQENDQ2ZOjwRERENU1mf7UJERESjS8Z6PjJl8eLFjtctLS1p9zVjufp7zRievs18n7mvNK5Dj2uasWP9uDNnzhyyzEMdx4yvxePxtO/r7e11vDbLp9M/l1mvEvMcUiw3ijinVD9uZZWuu35vmfVj3hN6jNi8tlJ59Osj3Wfmccz73uv5TdK97vas6dul59CNdO9JxzXj8tJYDf045vuk4/od/yFdH8B5jaTr43Zcr9fA7fnW70vpGTHPYY6LksYCSMeR7jVp/IXb2Axp/JfNddafE5t7QvoOcbtH0nEbf2Lzb1m695mCPN/psOeDiIiIIsXGBxEREUUqI6vaBpFMJsVuSJtucynsYRMSMUMX+nvNbnSv7zO7sczuSv1zSt3L0vkBZ5euWT82Xb9SN2hnZ2fa95lhF33qmTT1WrrOgPd6N69lc3Oz47U+VdCsH51bmEyil0fqsrUJ95n8huJsnhGT3uXu9zoD3sNkbqEDqWvYb92Zx9Q/lznNVPoOMc+hf07pnnQjdevrz+nrr78ulsfrOdxI9Txt2rTU7+YUfPO6S99b+n0pfaeZgoSrpXN6DWXYfFdLpLoySf/OuJ1Per7dptr29vaiuLhY3Ic9H0RERBQpNj6IiIgoUmx8EBERUaSG3ZgPG9I0UykmK8X+TWZcXI/HSed3K48Ul9fPYdaVFAO1KY9UPvOcemzQjFeb40FmzJiR+t0tBqp7+OGHHa9vv/32tPtKn9OsSylmrrOJCZu8jidyG1fidcquNI7DjRTblvb1O97CJI2NcLvX05XNPI7bcyA9M9I2qe78jj9zI43/snkOpHPqz7
D5fJtjAaT06l7HcZjbbb4bJdI53cZR2Izt80q6R8M6h3ROc+yeNNaIYz6IiIho2GPjg4iIiCLFxgcRERFFakSP+TD5HUdh8jpWwSZuaJMC22aeu99xAn7jjzapvW3OYZNnw+ZzeS2fW3m8jhmyuT4mm339Hsfm2up5G37xi184tukx4iBjPnQ2OR1sxjvYPJdel0hwG59iM95AOo7XMrjlppDyUejHMdOp244F8MNvniXAe10Gye0UxXMpsfmutHkuJG7l45gPIiIiyjlsfBAREVGkRlXYRRekq0qaduo1lGDT9Ssdx6270GsKdbfud72L0uxitwlZ6d22bmm3/ZK6L/3eW27H8dqFK3V3u51T4neqrc3UX5twgD49M0h6dZtr6XVabpBnxuu97nbtwgoBe52WaxN+dFspViJ9bmmlbpPXEESQ73Gb8HUYIWqb+85mKr/fcLpbaFCfemumTGDYhYiIiIYdNj6IiIgoUmx8EBERUaQKsl0AWzZptk1eU9SasTCbZdCl94UV0wuLFP+zmUbodUl04NipeukEuc5SfdmMf9C5XS/9tXSfmbHuTEyvcytrJu4nKW29G30Mkc11luL05meUxnHYjHsJa6yG1xTubmMRvN6zQZaMkFLB25DGjtiMndPZLEkg7ev2ufym3Pe6zdxuM94qyDl10jPc09Pj2BbGeD32fBAREVGk2PggIiKiSHGq7f/xO9VW2tdvtlHA+zSwsFb6NIU1vTjIdES/bFaK9drF7jZNTmfuu3HjxtTvCxYscGyzyaIqCStbrX7/mCEQ6V6Tumxtsl6GFVb1W3dAOCvFBskcGxYpw6nf7yZz5VPztddr7fYdkonMzJIgK5B7PU6QqbZ+s+mGNdXWDJlJqxcfPQ+n2hIREVFOYeODiIiIImXd+Hj11Vdx2WWXoby8HHl5eWhtbXVsV0rhvvvuw4knnohx48ahtrYWu3fvDqu8RERENMxZj/lYv349XnvtNVRXV+Oqq67CunXrHHHsBx54AE1NTXjiiSdQVVWFe++9F2+++SZ27dqFoqIi1+OHOebDb9zMJm4oxS6DrMboN0ZsUx79OM3NzY5t5lQqv7FuszxexwLYrCIbRUzYjdexPjarm/q9f9xiuX7HLJnjMfQpstJqp2Gl0s4GvzH8IOM4wjqOxG9qepM5FkC61vr94zaWJ+o6sJmK7Pe6m9+xM2bMcLz2+nwHWWHbZuyevt0s66ZNm9KW7+ix3MZ8WOf5mD9/PubPnz/kNqUUVqxYgZ/97Ge44oorAAB/+MMfUFpaitbWVlxzzTW2pyMiIqIRJtQxH3v27EEikUBtbW3qb/F4HHPmzEF7e/uQ7+nv70cymXT8EBER0cgVauMjkUgAAEpLSx1/Ly0tTW0zNTU1IR6Pp36mTp0aZpGIiIgox2Q9vfrSpUuxZMmS1OtkMhlaA8QmNieNjZCYYwr0OJmUptgtjul3fIo0bkA6ptsYDynviBQv1scFAPIS8pngNgbFZql1ide002Hlk5HitVJqcfO90hgUm9iyme/Baxp9kzSuxG1MjNexAUFSTuv1ZbNEgk1uHL9p/m1yi0ik8we5zvo4jyDp59PtFxWp7NJ1NuvO77gOt8/sNx2+9LnM6xxGvYfa81FWVgYA6O7udvy9u7s7tc0Ui8VQXFzs+CEiIqKRK9TGR1VVFcrKytDW1pb6WzKZxNatW1FTUxPmqYiIiGiYsp5q+8knn+Ddd98F8GXXy0MPPYS5c+di4sSJqKiowAMPPIDly5c7ptq+8cYbkUy19bsyIiBPzzT57day6WbzO61SOmcUKYTDqg830ufORIp5m9UzbVJFS/ymuLdZPVPidi31c5qrF+vhtbCmrttMDZTKagqrnvX7zgw3ZmMKcSbSzZtTLjs7Ox2vc22qdCbY/Pvk9ztfOk6QZRh0Ns+TzXT9o8cKfartjh07MHfu3NTro+M16urq0NLSgjvuuAOffvopbrzxRvT09OD888/Hiy++6K
nhQURERCOfdePjwgsvhNRZkpeXh/vvvx/3339/oIIRERHRyMS1XYiIiChSWZ9qG5TXWPtQ23U20xr9ToOVYrB+p2BmKk27DSlG7pdbPFQaCyBN0zPXItLTgNvUR5AxF+n4jQGbgoxZslnWW99uxv5t+B0vM23aNMdr/drapOMPi1vKcF0U6cNteF2SwLzO+vNj7jucuP1b4fV62Xzf+p1SHWT8js2/HenOHxb2fBAREVGk2PggIiKiSLHxQURERJGyzvORaW55Psx8AvoSzkHyfITFbyzXJkeBzVLQYaWc9nrOsPJ8mGm2zXi633ECUnp1vzlAhnpvuuPY5HOxyWXid6yG39TrgDOXRWVlZdqySsusA/JS62GN1YhizIcuyJglv3lRwvpcNnlPTFIZ9O9ut3siCjZ5LPzmFbIZS2iz/IYuyLIMOmmci81349Htbnk+2PNBREREkWLjg4iIiCI17MIupkykrrYhdZ1Jenp6HK/NKWte06vbsOneDWu6qPm59PoxV9KVZCqk5jXMEFa3rE0IxOR13yDXWQpD7d271/FaD7WY+7a0tKR+t+lit+netu0KTsfvve63C9vmHEH4DT/67bYf6jw6r+E+t+OEJaywi9clLIKEdaWyhfle3QUXXJD63TaNPsMuRERElHPY+CAiIqJIsfFBREREkRp2Yz7cpmDq/MYRbd7ndxnvILFTryncAf/xa5tpnzpzjIc5tkU6h87mOgfhdQqxWwxYGiuh85u22WbfIGnapXvLnOa+ePHitMddsWJF6nebsT2mqKfImqTnKaz4vs35JTYp5aXn3Wabye+YD0mQ7+NcTvceZMyQxG/697DSJBzdzjEfRERElFPY+CAiIqJIDbuwi8mmW9ZmapVfXrvSbMIcQcqaidUzpW59t+mQmcjiZ5OR0e90bL9T70x+3+fG5jo3NzenfjdDIlL59IymgPPaSqvaBplSrbOZDuk3S6jNe4Nkq/X7HEhlyNR3mF5Wt9WLvV7rqJ5Lv9cyyJT4qNncW/rnMKfOm6/N513HsAsRERENO2x8EBERUaTY+CAiIqJIDbsxH1GkKQ5yDr+xQb/x4yhi2zbHdYt1e027bXMNbNIJS3UQZOxItqeE6oKUVR8PMmPGDPE8ev3o7wOAxsbGtOcwSava6oKs3un3OQlrbFim7uewSCuY2sj2qrZR/PuQqRQOmfheD2tqrYljPoiIiGjYYeODiIiIIsXGBxEREUVq2I35CCLqWKrf+ddu5YlifIHfsptxRDMvgD6OIIqxEZmKAftNQW0T5w2r7DapmaWcLdL4GT2dOuDMEbBv3z7P5Quy7HhY+V100nNg8juOQzqO2/uksWHS+YOkN5dkIr16WN/bmRrL5/Vzud2TXvOX2KRel54nm2UYTBzzQURERMOOVeOjqakJZ599NiZMmIDJkydjwYIF6Orqcuxz6NAh1NfXY9KkSRg/fjwWLlyI7u7uUAtNREREw1eBzc6bNm1CfX09zj77bHzxxRe4++678Z3vfAe7du3C8ccfD+DLqXLPP/881q5di3g8joaGBlx11VV47bXXMvIBbEjdmZmY5mQeR+oWznY6X5suOJuwgtt0zXSiWL3YjXRP2Kz6K5VHCmfZdOPbdNVL5/A7zd1Mq20zrTITIcYgK59KdRBWPUuhJpswlNcud5tVmaXjZIrNM+z3+1CqZ7djStu9lscm5b5NqEu6R8zj6K9trqueIgEIZ6q0VePjxRdfPKZAkydPRkdHB771rW+ht7cXjz/+OFavXo2LLroIwJfz/0899VRs2bIF55xzTuACExER0fAWaMzH0RbWxIkTAQAdHR04cuQIamtrU/tMnz4dFRUVaG9vH/IY/f39SCaTjh8iIiIauXw3PgYHB9HY2IjzzjsPZ5xxBgAgkUigsLAQJSUljn1LS0uRSCSGPE5TUxPi8XjqZ+rUqX6LRERERMOA76m2N998M9avX4/NmzdjypQpAIDVq1fj+uuvR39/v2
Pf2bNnY+7cuXjggQeOOU5/f79j/2QyKTZAzLEAOmlcgI2wlns22Sy5rZOmPLrFBv2m/jXPKcUHvU4zNY8bZLpf1Knh3c4npeT2WndBlvWOIq2/SX8WzeW39am3bsus+51qKx3HZiyLdN2laxnkHvSaztxm+rXfbYBcP/q1bW1tdWxbvHix47XbtfbD5tn3ez9n6ns0yDm9shkH5HfJCLdzDrXdbaqt1ZiPoxoaGvDcc8/h1VdfTTU8AKCsrAyHDx9GT0+Po/eju7sbZWVlQx4rFoshFov5KQYRERENQ1ZhF6UUGhoasG7dOmzYsAFVVVWO7dXV1Rg7diza2tpSf+vq6sL+/ftRU1MTTomJiIhoWLMKu9xyyy1YvXo1nnnmGZxyyimpv8fjcYwbNw7Al+GYF154AS0tLSguLsatt94KAPjHP/7h6RxBMpyGtWpgrq+c63fVzSgy/C1YsMCxraenx/Hab9glU8LKTuh3SmhY4RKpO9W8R/RrMnfuXMe2TIT/3OpDmpabiSnw2Q7hhXnOsMI3Oin0pV8r4NjnWwq7eL3OQ5U3E2zCdF7/vQgSNpRk4jvf7d8DfaVqfZVqIAthl1WrVgE4toKbm5tTsb+HH34Y+fn5WLhwIfr7+zFv3jw8+uijNqchIiKiEcyq8eGlk6SoqAgrV67EypUrfReKiIiIRi6u7UJERESRGvar2trEhL1OQQoy5VESxrgA85xBptBJcXAp3bJJms5mlk9f5dZmWp7fMQVBYslhXS+dTXzW7wqvbu8La4VMXaampNqU1et96PbMZmJpA5tVosM6Tlhp2nVuKbm9Xmub5RNsBJmWK/F7b4U1HsRmPJHXe8tmxWSbcSZHy8hVbYmIiCinsPFBREREkWLjg4iIiCI17MZ82OQ3MHmd6x8k7a3fdLphxRH1tNaAc1yFTd1lIy2wzi0m7Hc8Rq4J63N4zfcAyDFhqTxmjgcztbbOb3r1TI2xsBnfZPOdku44NkvYh5XfRRoD45byX99us0S7yesYiyB5NSQ2deA31bh0HL/XxxRkHJvX58lmvKC5Xtu+ffvSHvdomTjmg4iIiHIKGx9EREQUqWEXdrGRqRUgvZ4zU+myMzW11Gt5pHOa3XPmaqd6mt5MpJgGoknb7rd71+R35dywQjQ2U7P9skmvXllZ6dgmPZeSIKnypX29vi+IKK6zzfeN/gybz7Ntd3y6c2Q7dBoklBHWatN+QztSCE16vs1zmEtj6Nc6E+nV2fNBREREkWLjg4iIiCLFxgcRERFFymphuVxnxu2kGJvfbW68jtWwKWtYy7dLbNKr24yxMGOFXqcYmlNtzViz/rkzNcbDZnl5nc30VZvz+50q7vcecZueGdb09L1796Z+l5Zaj+o665/LLZV1um1BrrPfabhhsVlawZzab14/nf5MS1PngcyMB7N5Lv3+WyK9z+0zSt9pNs+B1/Ehbuneze/usLHng4iIiCLFxgcRERFFathPtbXp6pw2bVrq99bWVsc2m+6xMFYqDNLN6DdToE24xGYqqXTMKKbUhTXVNqxQhs1qq16n17mdI6xVmf1O/TXZZDjVmVlU9W58v6uQAv6n09qEIHQ20yrDWvU3U9Mz050P8D/VNsiqtplIZ2D73nRspq7brARtM+Xb6z3h9u+s16y36crEqbZERESUU9j4ICIiokix8UFERESRGnZjPqRYYTbSJoc1TsBmul1Yq0NKsUGTtAKjVNbm5mbHa7/p1W3GAWUi/bzNqpd+V5kMi1t82iam71VLS4vjtT591i2ebzMFUyeNN8iF74IojhPWvSXdM26rJOv8jl/J1HPp9b1SvQLhPDM2U9elc7gdx88xh9o3yP3MMR9ERESUc9j4ICIiokix8UFERESRGnZjPkw2KXK95rwIK45okuZNh7UMuzR2xCY26HfJePN9nZ2djtczZswY8n3ZEsZS2eb2TGwLs6z6tTTHTfT09KR+N+8Jc/yOfi39xp1tuNWP17i4TX
zf5jvE7/MVJN+E9LkkYaXu95ubxu0zZyKXh9/8O4D3MXg2+VPC+o71Wx63MSf6WCxpeYuhcMwHERER5RyrxseqVatw1llnobi4GMXFxaipqcH69etT2w8dOoT6+npMmjQJ48ePx8KFC9Hd3R16oYmIiGj4sgq7PPvssxgzZgxOPvlkKKXwxBNP4MEHH8TOnTtx+umn4+abb8bzzz+PlpYWxONxNDQ0ID8/H6+99prnArmFXcJKpW3TPWfyes6wpodKx7UJNQUpj/5eswtO6m6Oojs+GzIxrdJmqpvZLbtgwYLU73roxI3NlGG/03Dd6kPq3rU5jk5KA27WTyamZ9pMgQ/ruQzyPmmqfban2gZJZxDWqrZ+yyOVzeT3WvoN/7m9T58+v3jxYrEMQ53HLexSIG41XHbZZY7Xy5Ytw6pVq7BlyxZMmTIFjz/+OFavXo2LLroIwJcx4lNPPRVbtmzBOeecY3MqIiIiGqF8j/kYGBjAmjVr8Omnn6KmpgYdHR04cuQIamtrU/tMnz4dFRUVaG9vT3uc/v5+JJNJxw8RERGNXNaNjzfffBPjx49HLBbDTTfdhHXr1uG0005DIpFAYWHhMaPnS0tLkUgk0h6vqakJ8Xg89TN16lTrD0FERETDh1XYBQBOOeUUdHZ2ore3F3/5y19QV1eHTZs2+S7A0qVLsWTJktTrZDIpNkDMeO3GjRtTv8+dO9exzYyhtba2pj2OHnd2i9fqr/VjDlWGdMzYts2S0plI1WzuJ417sVk22pxqq8eWpaXWzTT60lQv/R4AnOMfbOKsUtzX5jr7jeHbpDuWxh1J9WGyKatZB9Jx9eteVVXl+RxBSNdSquewxmbp3ylux9SvkfmM2HwXeC2b272l30/m/SPRp1sDEP8t8Dq2x2QzLsmkj2MI6143t+n/4Za+G/2OmXJjk/pd+g4x99XHedhcZ6+sGx+FhYU46aSTAADV1dXYvn07HnnkEVx99dU4fPgwenp6HBeju7sbZWVlaY8Xi8UQi8XsS05ERETDUuA8H4ODg+jv70d1dTXGjh2Ltra21Lauri7s378fNTU1QU9DREREI4RVz8fSpUsxf/58VFRUoK+vD6tXr8Yrr7yCl156CfF4HDfccAOWLFmCiRMnori4GLfeeitqamo404WIiIhSrBofBw8exHXXXYcPP/wQ8XgcZ511Fl566SV8+9vfBvBlTC8/Px8LFy5Ef38/5s2bh0cffTTUApvxUT126BZT098rxRxtYnNm3Nvr8uDm+cPKAyCNJbGJdUtzxW3KY+YPMAck+6WXwbwG0rX1e93Nc/iN35rntymPdE/oxzHLKsXFpbE15jbz2ZNIeSMkQerH6zabnA4mfTyEWc/6a7cxDWGNw/H6PrfPLN0/+jbzczU2Njpe+x0LYDNeRR9zIdWjDfMc0nGl+jGfGen7VyJdr7DGrrgdZ8WKFanfzTGSkY/5ePzxx8XtRUVFWLlyJVauXBmoUERERDRycW0XIiIiipT1bJdco3crSSEZwDldSO9SApzdh9L7ALlLTgoD2Uwv05mhCn2KqtnNZ5ZVCvXor91WMK2srEz9btP9bpPqWyKFk6SpgW4hCLP+vJ5fIt1b0nHdzqFfE7NepWvpNm3Zb3mkffXUzDbM40jPt3mv61MDpfqRzgE47xnzWnrt5repuyCkbn0pPCF9F9iUze/z7RZe01+b339ev3+Heq2T7i2Jed/t3bs39buUCsImRGTzb4e5Ta8v6XNJ7zOFFS7XseeDiIiIIsXGBxEREUWKjQ8iIiKKVJ5SSmW7ELpkMpmxNLRERESUWb29vSguLhb3Yc8HERERRYqNDyIiIooUGx9EREQUKTY+iIiIKFJsfBAREVGkcq7xkWOTb4iIiMiCl3/Hc67x0dfXl+0iEBERkU9e/h3PuTwfg4OD+OCDD6CUQkVFBQ4cOOA6X3g0SiaTmDp1KusnDdaPjPUjY/3IWD/pjea6UUqhr68P5eXlyM+X+z
ZybmG5/Px8TJkyBclkEgBQXFw86i6gDdaPjPUjY/3IWD8y1k96o7VuvCYJzbmwCxEREY1sbHwQERFRpHK28RGLxfDzn/8csVgs20XJSawfGetHxvqRsX5krJ/0WDfe5NyAUyIiIhrZcrbng4iIiEYmNj6IiIgoUmx8EBERUaTY+CAiIqJIsfFBREREkcrZxsfKlStRWVmJoqIizJkzB9u2bct2kSLX1NSEs88+GxMmTMDkyZOxYMECdHV1OfY5dOgQ6uvrMWnSJIwfPx4LFy5Ed3d3lkqcXcuXL0deXh4aGxtTfxvt9fP+++/jBz/4ASZNmoRx48bhzDPPxI4dO1LblVK47777cOKJJ2LcuHGora3F7t27s1ji6AwMDODee+9FVVUVxo0bh69//ev45S9/6VgUazTVz6uvvorLLrsM5eXlyMvLQ2trq2O7l7r4+OOPsWjRIhQXF6OkpAQ33HADPvnkkwg/ReZI9XPkyBHceeedOPPMM3H88cejvLwc1113HT744APHMUZy/VhTOWjNmjWqsLBQ/f73v1f/+te/1I9//GNVUlKiuru7s120SM2bN081Nzert956S3V2dqrvfve7qqKiQn3yySepfW666SY1depU1dbWpnbs2KHOOeccde6552ax1Nmxbds2VVlZqc466yx12223pf4+muvn448/VtOmTVOLFy9WW7duVe+995566aWX1LvvvpvaZ/ny5Soej6vW1lb1+uuvq8svv1xVVVWpzz//PIslj8ayZcvUpEmT1HPPPaf27Nmj1q5dq8aPH68eeeSR1D6jqX5eeOEFdc8996inn35aAVDr1q1zbPdSF5dccon6xje+obZs2aL+/ve/q5NOOklde+21EX+SzJDqp6enR9XW1qqnnnpKvf3226q9vV3Nnj1bVVdXO44xkuvHVk42PmbPnq3q6+tTrwcGBlR5eblqamrKYqmy7+DBgwqA2rRpk1Lqyxt+7Nixau3atal9/v3vfysAqr29PVvFjFxfX586+eST1csvv6wuuOCCVONjtNfPnXfeqc4///y02wcHB1VZWZl68MEHU3/r6elRsVhM/elPf4qiiFl16aWXqh/96EeOv1111VVq0aJFSqnRXT/mP65e6mLXrl0KgNq+fXtqn/Xr16u8vDz1/vvvR1b2KAzVODNt27ZNAVD79u1TSo2u+vEi58Iuhw8fRkdHB2pra1N/y8/PR21tLdrb27NYsuzr7e0FAEycOBEA0NHRgSNHjjjqavr06aioqBhVdVVfX49LL73UUQ8A6+evf/0rZs2ahe9973uYPHkyZs6cid/97nep7Xv27EEikXDUTzwex5w5c0ZF/Zx77rloa2vDO++8AwB4/fXXsXnzZsyfPx8A60fnpS7a29tRUlKCWbNmpfapra1Ffn4+tm7dGnmZs623txd5eXkoKSkBwPox5dyqth999BEGBgZQWlrq+HtpaSnefvvtLJUq+wYHB9HY2IjzzjsPZ5xxBgAgkUigsLAwdXMfVVpaikQikYVSRm/NmjX45z//ie3btx+zbbTXz3vvvYdVq1ZhyZIluPvuu7F9+3b85Cc/QWFhIerq6lJ1MNSzNhrq56677kIymcT06dMxZswYDAwMYNmyZVi0aBEAjPr60Xmpi0QigcmTJzu2FxQUYOLEiaOuvg4dOoQ777wT1157bWplW9aPU841Pmho9fX1eOutt7B58+ZsFyVnHDhwALfddhtefvllFBUVZbs4OWdwcBCzZs3Cr371KwDAzJkz8dZbb+Gxxx5DXV1dlkuXfX/+85/x5JNPYvXq1Tj99NPR2dmJxsZGlJeXs37ItyNHjuD73/8+lFJYtWpVtouTs3Iu7HLCCSdgzJgxx8xI6O7uRllZWZZKlV0NDQ147rnnsHHjRkyZMiX197KyMhw+fBg9PT2O/UdLXXV0dODgwYP45je/iYKCAhQUFGDTpk349a9/jYKCApSWlo7q+jnxxBNx2mmnOf526qmnYv
/+/QCQqoPR+qz99Kc/xV133YVrrrkGZ555Jn74wx/i9ttvR1NTEwDWj85LXZSVleHgwYOO7V988QU+/vjjUVNfRxse+/btw8svv5zq9QBYP6aca3wUFhaiuroabW1tqb8NDg6ira0NNTU1WSxZ9JRSaGhowLp167BhwwZUVVU5tldXV2Ps2LGOuurq6sL+/ftHRV1dfPHFePPNN9HZ2Zn6mTVrFhYtWpT6fTTXz3nnnXfM1Ox33nkH06ZNAwBUVVWhrKzMUT/JZBJbt24dFfXz2WefIT/f+RU4ZswYDA4OAmD96LzURU1NDXp6etDR0ZHaZ8OGDRgcHMScOXMiL3PUjjY8du/ejb/97W+YNGmSY/tor59jZHvE61DWrFmjYrGYamlpUbt27VI33nijKikpUYlEIttFi9TNN9+s4vG4euWVV9SHH36Y+vnss89S+9x0002qoqJCbdiwQe3YsUPV1NSompqaLJY6u/TZLkqN7vrZtm2bKigoUMuWLVO7d+9WTz75pDruuOPUH//4x9Q+y5cvVyUlJeqZZ55Rb7zxhrriiitG7FRSU11dnfrqV7+ammr79NNPqxNOOEHdcccdqX1GU/309fWpnTt3qp07dyoA6qGHHlI7d+5MzdbwUheXXHKJmjlzptq6davavHmzOvnkk0fMVFKpfg4fPqwuv/xyNWXKFNXZ2en4vu7v708dYyTXj62cbHwopdRvfvMbVVFRoQoLC9Xs2bPVli1bsl2kyAEY8qe5uTm1z+eff65uueUW9ZWvfEUdd9xx6sorr1Qffvhh9gqdZWbjY7TXz7PPPqvOOOMMFYvF1PTp09Vvf/tbx/bBwUF17733qtLSUhWLxdTFF1+surq6slTaaCWTSXXbbbepiooKVVRUpL72ta+pe+65x/GPxWiqn40bNw75fVNXV6eU8lYX//3vf9W1116rxo8fr4qLi9X111+v+vr6svBpwifVz549e9J+X2/cuDF1jJFcP7bylNLS+RERERFlWM6N+SAiIqKRjY0PIiIiihQbH0RERBQpNj6IiIgoUmx8EBERUaTY+CAiIqJIsfFBREREkWLjg4iIiCLFxgcRERFFio0PIiIiihQbH0RERBSp/wFoQYQXP1ktRwAAAABJRU5ErkJggg==",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 0.0026, norm: 0.0014, param_norm: 935.4885, vb_loss: 0.0008, ce_loss: 0.7653: 19%|█▉ | 298/1563 [00:23<00:50, 25.17it/s]"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVCUlEQVR4nO3df2xV9f3H8VdL29tqube0rPfSwZVuklUHONZCuWL2y7sAMwqj24SwUZXMwIqjNpmIDvbHxkq2ZKKLYrZlsGUyHIngZFPCCsJYaks76qyVirGRRrgXHfbegtJi72d/fOP9cvlRekv7ube9z0fySeg5h3vffX9o74vPOefeNGOMEQAAgCXpiS4AAACkFsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsGrYwseTTz6pyZMnKzs7W+Xl5WpsbByupwIAACNI2nB8tsuzzz6rZcuW6emnn1Z5ebk2bdqkHTt2qL29XYWFhf3+3UgkohMnTmjs2LFKS0sb6tIAAMAwMMaou7tbRUVFSk+/ytqGGQazZs0yVVVV0a/7+vpMUVGRqa2tverf7ezsNJIYDAaDwWCMwNHZ2XnV1/ohP+3S29ur5uZm+f3+6Lb09HT5/X7V19dfcnxPT4/C4XB0GD5kFwCAEWvs2LFXPWbIw8f777+vvr4+ud3umO1ut1uBQOCS42tra+VyuaLD6/UOdUkAAMCSgVwykfC7XdauXatQKBQdnZ2diS4JAAAMo4yhfsDx48drzJgxCgaDMduDwaA8Hs8lxzscDjkcjqEuAwAAJKkhX/nIyspSaWmp6urqotsikYjq6urk8/mG+ukAAMAIM+QrH5JUU1OjyspKlZWVadasWdq0aZPOnj2re++9dzieDgAAjCDDEj7uvvtuvffee1q/fr0CgYC+8IUv6KWXXrrkIlQAAJB6huVNxq5FOByWy+VKdBkAAGAQQqGQnE5nv8ck/G4XAACQWggfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwKqMRBeA+H3wwQfD/hzjxo0b9udAfIZj3pnnxLPx83wx5t0+5jkWKx8AAMAqwgcAALAq7vBx8OBB3XnnnSoqKlJaWpp27doVs98Yo/Xr12vChAnKycmR3+/XsWPHhqpeAAAwwsV9zcfZs2d1yy236L777tOiRYsu2f+LX/xCTzzxhP7whz+ouLhY69at09y5c9XW1qbs7OwhKXo0SsT5QAyPkTSXF9eazOeIk81Immfm9dow10Mv7vAxf/58zZ8//7L7jDHatGmTfvzjH2vBggWSpD/+8Y9yu93atWuXFi9efG3VAgCAEW9Ir/no6OhQIBCQ3++PbnO5XCovL1d9ff1l/05PT4/C4XDMAAAAo9eQho9AICBJcrvdMdvdbnd038Vqa2vlcrmiY9KkSUNZEgAASDIJf5+PtWvXqqamJvp1OBwe0QEk2c4NjpTzfyNNoueZebWDeU4NiZ5nKfXmekhXPjwejyQpGAzGbA8Gg9F9F3M4HHI6nTEDAACMXkMaPoqLi+XxeFRXVxfdFg6H1dDQIJ/PN5RPBQAARqi4T7ucOXNGb731VvTrjo4OtbS0KD8/X16vV9XV1frZz36mKVOmRG+1LSoq0sKFC4eyblxBqi3d2ZIMy7IAMFrEHT6ampr01a9+Nfr1J9drVFZWauvWrXrooYd09uxZ3X///erq6tJtt92ml156iff4AAAAkqQ0Y4xJdBEXCofDcrlciS5j0BL9P2RWPoZHouf1YsyzHYmed+bZjkTPszS65joUCl31+k0+2wUAAFjFygcAABgyrHwAAICkQ/gAAABWET4AAIBVhA
8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFiVkegCAGC0++CDD664b9y4cRYrAZIDKx8AAMAqwgcAALCK8AEAAKzimg8AGGZc1wHEYuUDAABYFVf4qK2t1cyZMzV27FgVFhZq4cKFam9vjznm3LlzqqqqUkFBgXJzc1VRUaFgMDikRQMAgJErzRhjBnrwvHnztHjxYs2cOVMff/yxHnnkEbW2tqqtrU3XX3+9JGnlypX629/+pq1bt8rlcmnVqlVKT0/Xv/71rwE9RzgclsvlGtx3k4L6u4XvYiz9jl4X/jtgnkc25hIjXSgUktPp7PeYuMLHxd577z0VFhbqwIED+tKXvqRQKKRPfepT2rZtm771rW9Jko4ePaqbbrpJ9fX1mj179lUfk/ARH8IHJF6wRhPmEiPdQMLHNV3zEQqFJEn5+fmSpObmZp0/f15+vz96TElJibxer+rr6y/7GD09PQqHwzEDAACMXoMOH5FIRNXV1ZozZ46mTp0qSQoEAsrKylJeXl7MsW63W4FA4LKPU1tbK5fLFR2TJk0abEkAAGAEGPSttlVVVWptbdWhQ4euqYC1a9eqpqYm+nU4HCaAAJcRzyk2jA4XzzmnYUaPVH/L/UGFj1WrVmn37t06ePCgJk6cGN3u8XjU29urrq6umNWPYDAoj8dz2cdyOBxyOByDKQMAAIxAcZ12McZo1apV2rlzp/bt26fi4uKY/aWlpcrMzFRdXV10W3t7u44fPy6fzzc0FQMAgBEtrpWPqqoqbdu2Tc8//7zGjh0bvY7D5XIpJydHLpdLy5cvV01NjfLz8+V0OvXAAw/I5/MN6E4XAP+PO5mA0YPTprHiCh+bN2+WJH3lK1+J2b5lyxbdc889kqTHHntM6enpqqioUE9Pj+bOnaunnnpqSIoFAAAj3zW9z8dw4H0+4sP/jkcv5jY1pfqFiKNVKv08D/v7fAAAAMSLT7VNUkN1fnCkJ+jRhvO+qYl5H72Y28Fh5QMAAFhF+AAAAFYRPgAAgFVc85Egw3WekGs8ks9wzDXznHyY59HLxnUdqTbXrHwAAACrCB8AAMAqTrsMI06tpA6W3FMD85wamOfhx8oHAACwivABAACsInwAAACruObjGnFdR2pgnlMH5/tTA/OcWKx8AAAAqwgfAADAKk67DMJQLNexPJd8OLWSGpjn1MGpleTFygcAALCK8AEAAKwifAAAAKu45mMYcW4w+Q3VOWHmOrkxz6mBeR45WPkAAABWET4AAIBVhA8AAGBVmjHGJLqIC4XDYblcrkSXAQAABiEUCsnpdPZ7DCsfAADAqrjCx+bNmzV9+nQ5nU45nU75fD69+OKL0f3nzp1TVVWVCgoKlJubq4qKCgWDwSEvGgAAjFxxhY+JEydq48aNam5uVlNTk772ta9pwYIFev311yVJDz74oF544QXt2LFDBw4c0IkTJ7Ro0aJhKRwAAIxQ5hqNGzfO/O53vzNdXV0mMzPT7NixI7rvjTfeMJJMfX39gB8vFAoZSQwGg8FgMEbgCIVCV32tH/Q1H319fdq+fbvOnj0rn8+n5uZmnT9/Xn6/P3pMSUmJvF6v6uvrr/g4PT09CofDMQMAAIxecYeP1157Tbm5uXI4HFqxYoV27typm2++WYFAQFlZWcrLy4s53u12KxAIXPHxamtr5XK5omPSpElxfxMAAGDkiDt8fO5zn1NLS4saGhq0cuVKVVZWqq2tbdAFrF27VqFQKDo6OzsH/VgAACD5xf3ZLllZWbrxxhslSaWlpTp8+LAef/xx3X333ert7VVXV1fM6kcwGJTH47ni4zkcDjkcjvgrBwAAI9I1v89HJBJRT0+PSktLlZmZqbq6uui+9vZ2HT9+XD6f71qfBgAAjBJxrXysXbtW8+fPl9frVXd3t7Zt26aXX35Ze/bskcvl0vLly1
VTU6P8/Hw5nU498MAD8vl8mj179nDVDwAARpi4wsepU6e0bNkynTx5Ui6XS9OnT9eePXv09a9/XZL02GOPKT09XRUVFerp6dHcuXP11FNPDUvhAABgZOKzXQAAwJDhs10AAEDSIXwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAqqQLH0n2UTMAACAOA3kdT7rw0d3dnegSAADAIA3kdTzpPtU2EonoxIkTMsbI6/Wqs7Pzqp+Ol4rC4bAmTZpEf66A/vSP/vSP/vSP/lxZKvfGGKPu7m4VFRUpPb3/tY0MSzUNWHp6uiZOnKhwOCxJcjqdKTeB8aA//aM//aM//aM//aM/V5aqvXG5XAM6LulOuwAAgNGN8AEAAKxK2vDhcDj0k5/8RA6HI9GlJCX60z/60z/60z/60z/6c2X0ZmCS7oJTAAAwuiXtygcAABidCB8AAMAqwgcAALCK8AEAAKwifAAAAKuSNnw8+eSTmjx5srKzs1VeXq7GxsZEl2RdbW2tZs6cqbFjx6qwsFALFy5Ue3t7zDHnzp1TVVWVCgoKlJubq4qKCgWDwQRVnFgbN25UWlqaqquro9tSvT/vvvuuvvvd76qgoEA5OTmaNm2ampqaovuNMVq/fr0mTJignJwc+f1+HTt2LIEV29PX16d169apuLhYOTk5+uxnP6uf/vSnMR+KlUr9OXjwoO68804VFRUpLS1Nu3btitk/kF6cPn1aS5culdPpVF5enpYvX64zZ85Y/C6GT3/9OX/+vNasWaNp06bp+uuvV1FRkZYtW6YTJ07EPMZo7k/cTBLavn27ycrKMr///e/N66+/br7//e+bvLw8EwwGE12aVXPnzjVbtmwxra2tpqWlxXzjG98wXq/XnDlzJnrMihUrzKRJk0xdXZ1pamoys2fPNrfeemsCq06MxsZGM3nyZDN9+nSzevXq6PZU7s/p06fNDTfcYO655x7T0NBg3n77bbNnzx7z1ltvRY/ZuHGjcblcZteuXebVV181d911lykuLjYfffRRAiu3Y8OGDaagoMDs3r3bdHR0mB07dpjc3Fzz+OOPR49Jpf78/e9/N48++qh57rnnjCSzc+fOmP0D6cW8efPMLbfcYl555RXzz3/+09x4441myZIllr+T4dFff7q6uozf7zfPPvusOXr0qKmvrzezZs0ypaWlMY8xmvsTr6QMH7NmzTJVVVXRr/v6+kxRUZGpra1NYFWJd+rUKSPJHDhwwBjzf//gMzMzzY4dO6LHvPHGG0aSqa+vT1SZ1nV3d5spU6aYvXv3mi9/+cvR8JHq/VmzZo257bbbrrg/EokYj8djfvnLX0a3dXV1GYfDYf785z/bKDGh7rjjDnPffffFbFu0aJFZunSpMSa1+3Pxi+tAetHW1mYkmcOHD0ePefHFF01aWpp59913rdVuw+XC2cUaGxuNJPPOO+8YY1KrPwORdKddent71dzcLL/fH92Wnp4uv9+v+vr6BFaWeKFQSJKUn58vSWpubtb58+djelVSUiKv15tSvaqqqtIdd9wR0weJ/vz1r39VWVmZvv3tb6uwsFAzZszQb3/72+j+jo4OBQKBmP64XC6Vl5enRH9uvfVW1dXV6c0335Qkvfrqqzp06JDmz58vif5caCC9qK+vV15ensrKyqLH+P1+paenq6GhwXrNiRYKhZSWlqa8vDxJ9OdiSfeptu+//776+vrkdrtjtrvdbh09ejRBVSVeJBJRdXW15syZo6lTp0qSAoGAsrKyov+4P+F2uxUIBBJQpX3bt2/Xv//9bx0+fPiSfanen7ffflubN29WTU2NHnnkER0+fFg//OEPlZWVpcrKymgPLvezlgr9efjhhxUOh1VSUqIxY8aor69PGzZs0NKlSyUp5ftzoYH0IhAIqLCwMGZ/RkaG8vPzU65f586d05o1a7RkyZLoJ9vSn1hJFz5weVVVVWptbd
WhQ4cSXUrS6Ozs1OrVq7V3715lZ2cnupykE4lEVFZWpp///OeSpBkzZqi1tVVPP/20KisrE1xd4v3lL3/RM888o23btunzn/+8WlpaVF1draKiIvqDQTt//ry+853vyBijzZs3J7qcpJV0p13Gjx+vMWPGXHJHQjAYlMfjSVBVibVq1Srt3r1b+/fv18SJE6PbPR6Pent71dXVFXN8qvSqublZp06d0he/+EVlZGQoIyNDBw4c0BNPPKGMjAy53e6U7s+ECRN08803x2y76aabdPz4cUmK9iBVf9Z+9KMf6eGHH9bixYs1bdo0fe9739ODDz6o2tpaSfTnQgPphcfj0alTp2L2f/zxxzp9+nTK9OuT4PHOO+9o79690VUPif5cLOnCR1ZWlkpLS1VXVxfdFolEVFdXJ5/Pl8DK7DPGaNWqVdq5c6f27dun4uLimP2lpaXKzMyM6VV7e7uOHz+eEr26/fbb9dprr6mlpSU6ysrKtHTp0uifU7k/c+bMueTW7DfffFM33HCDJKm4uFgejyemP+FwWA0NDSnRnw8//FDp6bG/AseMGaNIJCKJ/lxoIL3w+Xzq6upSc3Nz9Jh9+/YpEomovLzces22fRI8jh07pn/84x8qKCiI2Z/q/blEoq94vZzt27cbh8Nhtm7datra2sz9999v8vLyTCAQSHRpVq1cudK4XC7z8ssvm5MnT0bHhx9+GD1mxYoVxuv1mn379pmmpibj8/mMz+dLYNWJdeHdLsakdn8aGxtNRkaG2bBhgzl27Jh55plnzHXXXWf+9Kc/RY/ZuHGjycvLM88//7z5z3/+YxYsWDBqbyW9WGVlpfn0pz8dvdX2ueeeM+PHjzcPPfRQ9JhU6k93d7c5cuSIOXLkiJFkfvWrX5kjR45E79YYSC/mzZtnZsyYYRoaGsyhQ4fMlClTRs2tpP31p7e319x1111m4sSJpqWlJeb3dU9PT/QxRnN/4pWU4cMYY379618br9drsrKyzKxZs8wrr7yS6JKsk3TZsWXLlugxH330kfnBD35gxo0bZ6677jrzzW9+05w8eTJxRSfYxeEj1fvzwgsvmKlTpxqHw2FKSkrMb37zm5j9kUjErFu3zrjdbuNwOMztt99u2tvbE1StXeFw2Kxevdp4vV6TnZ1tPvOZz5hHH3005sUilfqzf//+y/6+qaysNMYMrBf//e9/zZIlS0xubq5xOp3m3nvvNd3d3Qn4boZef/3p6Oi44u/r/fv3Rx9jNPcnXmnGXPB2fgAAAMMs6a75AAAAoxvhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFb9D9oAs+8UV1pFAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 0.0011, norm: 0.0024, param_norm: 936.5659, vb_loss: 0.0005, ce_loss: 0.4201: 38%|███▊ | 598/1563 [00:44<00:38, 24.82it/s]"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWY0lEQVR4nO3df2yV5f3/8VdL6Slazikt6Tl0cKSbZOiAjRUoR8x0ehZgRmGQTQgb1ZEZWHHUJhPRwRI3VpYlE7cgZr9gy2Q4EsHJBoQVhGFKgY46EK0YmTTCOcyxnlNQWuy5Pn98v97pKVDOKe19nx/PR3Il7X1fPefd99X79J3run/kGGOMAAAAbJLrdAAAACC7UHwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbDVjxsW7dOo0ePVoFBQWqrKzUoUOHBuqtAABAGskZiGe7vPjii1q4cKGef/55VVZWau3atdqyZYtaWlpUWlra68/GYjGdOXNGQ4cOVU5OTn+HBgAABoAxRu3t7SorK1Nu7nXmNswAmDJliqmurra+7+rqMmVlZaauru66P9va2mok0Wg0Go1GS8PW2tp63f/1/b7s0tnZqaamJgWDQWtbbm6ugsGgGhoarujf0dGhaDRqNcNDdgEASFtDhw69bp9+Lz4++OADdXV1yev1xm33er0KhUJX9K+rq5PH47Ga3+/v75AAAIBNEjllwvGrXVasWKFIJGK11tZWp0MCAAADKK+/X3D48OEaNGiQwuFw3PZwOCyfz3dFf5fLJZfL1d9hAACAFNXvMx/5+fmqqKhQfX29tS0Wi6m+vl6BQKC/3w4AAKSZfp/5kKTa2lpVVVVp0qRJmjJlitauXauLFy/q4YcfHoi3AwAAaWRAio8HH3xQ//nPf7Rq1SqFQiF94Qtf0M6dO684CRUAAGSfAbnJ2I2IRqPyeDxOhwEAAPogEonI7Xb32sfxq10AAEB2ofgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ynM6ACTvf//7X8J9a2pqrK/Xrl2b8M8NGzYsiYgwEOwY554Yd/v1dZylvo8142w/u8Y5XcaWmQ8AAGArig8AAGCrpIuP/fv36/7771dZWZlycnK0bdu2uP3GGK1atUojRozQkCFDFAwGdfLkyf6KFwAApLkcY4xJ5gd27Nih1157TRUVFZozZ462bt2q2bNnW/t/+tOfqq6uTr///e9VXl6ulStX6tixYzpx4oQKCgqu+/rRaFQejyfpXyRVdF/X67n2lsyaX6pJl3VEu1RVVV1z342cc+E0xjlxHM+Zo+dY9te5NU5IhbGNRCJyu9299kn6hNOZM2dq5syZV91njNHatWv1gx/8QLNmzZIk/eEPf5DX69W2bds0b968ZN8OAABkmH495+PUqVMKhUIKBoPWNo/Ho8rKSjU0NFz1Zzo6OhSNRuMaAADIXP1afIRCIUmS1+uN2+71eq19PdXV1cnj8Vht1KhR/RkSAABIMY7f52PFihWqra21vo9GoylfgCS61uvEmnAqrPdlilRe02ecB07383mcXutnnAdOqp23lehYp/LnUjL6debD5/NJksLhcNz2cDhs7evJ5XLJ7XbHNQAAkLn6tfgoLy+Xz+dTfX29tS0ajaqxsVGBQKA/3woAAKSppJddLly4oHfeecf6/tSpU2publZxcbH8fr9qamr04x//WGPGjLEutS0rK4u7HBfX130KLlOm2XAlxjn1DMSUeyZddp8p0nWcM2UpLuni48iRI/ryl79sff/J+RpVVVXauHGjHn/8cV28eFGPPPKI2tradOedd2rnzp
0J3eMDAABkvqSLj7vvvlu93ZcsJydHTz/9tJ5++ukbCgwAAGQmnu0CAABslfTt1QdaJt1ePRmZso6XLRjnzNXb2CZ6jg7jnN6SOb4Z6yslcnt1Zj4AAICtKD4AAICtKD4AAICtHL+9eqbpbf2Pa/szB+OcuRJdw+feHZmL43vgMfMBAABsRfEBAABsxbJLiug5lcflW5mJcc4OjHP26D7WjHPimPkAAAC2ovgAAAC2ovgAAAC24pwPG/H49Ozg9CWYnG/gDI7v7OD08Z0pmPkAAAC2ovgAAAC2YtklRfU2lTd79mzr63379tkQDQbKQDw9k2WW5PTXMlVvY8lTUrMD45w4Zj4AAICtKD4AAICtKD4AAICtcowxxukguotGo/J4PE6HYbuBulwr29cVU01/jTPj2n/6a53ejksuGffUMxDjnu7jHIlE5Ha7e+3DzAcAALAVxQcAALAVxQcAALAV53ykgRtZU+SeIOmDewSkhkTHYaDGoLdHtHPr/IExUHnt6zGd7uPMOR8AACDlJFV81NXVafLkyRo6dKhKS0s1e/ZstbS0xPW5dOmSqqurVVJSosLCQs2dO1fhcLhfgwYAAOkrqWWXGTNmaN68eZo8ebI+/vhjPfnkkzp+/LhOnDihm2++WZK0ZMkS/fWvf9XGjRvl8Xi0dOlS5ebm6rXXXkvoPVh2SQ5T9dmjt7FmbAeG0zm/3vHNuPcPJ/KcyZ/diSy7JPVsl507d8Z9v3HjRpWWlqqpqUlf+tKXFIlE9Nvf/labNm3SPffcI0nasGGDbrvtNh08eFBTp05N8lcAAACZ5obO+YhEIpKk4uJiSVJTU5MuX76sYDBo9Rk7dqz8fr8aGhqu+hodHR2KRqNxDQAAZK4+Fx+xWEw1NTWaNm2axo0bJ0kKhULKz89XUVFRXF+v16tQKHTV16mrq5PH47HaqFGj+hoSAABIA0ktu3RXXV2t48eP68CBAzcUwIoVK1RbW2t9H41GKUCAJPV2eSb6zolc2nGbdsTjmLFfn4qPpUuXavv27dq/f79Gjhxpbff5fOrs7FRbW1vc7Ec4HJbP57vqa7lcLrlcrr6EAQAA0lBSyy7GGC1dulRbt27Vnj17VF5eHre/oqJCgwcPVn19vbWtpaVFp0+fViAQ6J+IAQBAWktq5qO6ulqbNm3Syy+/rKFDh1rncXg8Hg0ZMkQej0eLFi1SbW2tiouL5Xa79eijjyoQCHClCxzBpYpwwvXuUMnSSv/g+E5fSRUf69evlyTdfffdcds3bNighx56SJL0zDPPKDc3V3PnzlVHR4emT5+u5557rl+CBQAA6S+p4iOR+5EVFBRo3bp1WrduXZ+DAgAAmYtnuwAAAFv1+VLbdJfqTw286667rK+3bdvmXCBpLhXW2quqqqyv//3vf1+zH+Oc3nr72+p+PN+IVPuccloqXArdPYae48wxfW3MfAAAAFtRfAAAAFtRfAAAAFvlmEQuYbFRNBqVx+NxOowB0X3tv6e1a9f2y3uwJuw8xjk79Bzn7ufzDNRaP+PuDI7p5EQiEbnd7l77MPMBAABsRfEBAABslbWX2tqh52VXfZ2ey6TpuEw1EJfwMu6pp/v0O9PtmavnMkt/jDXjHI+ZDwAAYCuKDwAAYCuKDwAAYCvO+egDLrvKDoxzdmCcswPjnFqY+QAAALai+AAAALZi2eUq7HjyKdNzqYGxzkxOPL2YcXYex3P6YOYDAADYiuIDAADYiuIDAADYinM+/j/WCrODE+cCABgYHM/pi5kPAABgK4oPAABgK4oPAABgqxxjjHE6iO6i0ag8Ho/TYQAAgD6IRCJyu9299mHmAwAA2Cqp4mP9+vWaMGGC3G633G63AoGAduzYYe2/dOmSqqurVVJSosLCQs2dO1fhcLjfgwYAAOkrqeJj5MiRWrNmjZqamnTkyBHdc889mjVrlt544w1J0mOPPaZXXnlFW7Zs0b59+3TmzBnNmTNnQAIHAABpytygYcOGmd
/85jemra3NDB482GzZssXa9+abbxpJpqGhIeHXi0QiRhKNRqPRaLQ0bJFI5Lr/6/t8zkdXV5c2b96sixcvKhAIqKmpSZcvX1YwGLT6jB07Vn6/Xw0NDdd8nY6ODkWj0bgGAAAyV9LFx7Fjx1RYWCiXy6XFixdr69atuv322xUKhZSfn6+ioqK4/l6vV6FQ6JqvV1dXJ4/HY7VRo0Yl/UsAAID0kXTx8dnPflbNzc1qbGzUkiVLVFVVpRMnTvQ5gBUrVigSiVittbW1z68FAABSX9LPdsnPz9ett94qSaqoqNDhw4f17LPP6sEHH1RnZ6fa2triZj/C4bB8Pt81X8/lcsnlciUfOQAASEs3fJ+PWCymjo4OVVRUaPDgwaqvr7f2tbS06PTp0woEAjf6NgAAIEMkNfOxYsUKzZw5U36/X+3t7dq0aZNeffVV7dq1Sx6PR4sWLVJtba2Ki4vldrv16KOPKhAIaOrUqQMVPwAASDNJFR/nzp3TwoULdfbsWXk8Hk2YMEG7du3SV77yFUnSM888o9zcXM2dO1cdHR2aPn26nnvuuQEJHAAApCee7QIAAPoNz3YBAAAph+IDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYKuWKjxR71AwAAEhCIv/HU674aG9vdzoEAADQR4n8H0+5p9rGYjGdOXNGxhj5/X61trZe9+l42SgajWrUqFHk5xrIT+/IT+/IT+/Iz7Vlc26MMWpvb1dZWZlyc3uf28izKaaE5ebmauTIkYpGo5Ikt9uddQOYDPLTO/LTO/LTO/LTO/JzbdmaG4/Hk1C/lFt2AQAAmY3iAwAA2Cpliw+Xy6Uf/vCHcrlcToeSkshP78hP78hP78hP78jPtZGbxKTcCacAACCzpezMBwAAyEwUHwAAwFYUHwAAwFYUHwAAwFYUHwAAwFYpW3ysW7dOo0ePVkFBgSorK3Xo0CGnQ7JdXV2dJk+erKFDh6q0tFSzZ89WS0tLXJ9Lly6purpaJSUlKiws1Ny5cxUOhx2K2Flr1qxRTk6OampqrG3Znp/3339f3/zmN1VSUqIhQ4Zo/PjxOnLkiLXfGKNVq1ZpxIgRGjJkiILBoE6ePOlgxPbp6urSypUrVV5eriFDhugzn/mMfvSjH8U9FCub8rN//37df//9KisrU05OjrZt2xa3P5FcnD9/XgsWLJDb7VZRUZEWLVqkCxcu2PhbDJze8nP58mUtX75c48eP180336yysjItXLhQZ86ciXuNTM5P0kwK2rx5s8nPzze/+93vzBtvvGG+853vmKKiIhMOh50OzVbTp083GzZsMMePHzfNzc3mq1/9qvH7/ebChQtWn8WLF5tRo0aZ+vp6c+TIETN16lRzxx13OBi1Mw4dOmRGjx5tJkyYYJYtW2Ztz+b8nD9/3txyyy3moYceMo2Njebdd981u3btMu+8847VZ82aNcbj8Zht27aZ119/3TzwwAOmvLzcfPTRRw5Gbo/Vq1ebkpISs337dnPq1CmzZcsWU1hYaJ599lmrTzbl529/+5t56qmnzEsvvWQkma1bt8btTyQXM2bMMJ///OfNwYMHzT/+8Q9z6623mvnz59v8mwyM3vLT1tZmgsGgefHFF81bb71lGhoazJQpU0xFRUXca2RyfpKVksXHlClTTHV1tfV9V1eXKSsrM3V1dQ5G5bxz584ZSWbfvn3GmP/3Bz948GCzZcsWq8+bb75pJJmGhganwrRde3u7GTNmjNm9e7e56667rOIj2/OzfPlyc+edd15zfywWMz6fz/zsZz+ztrW1tRmXy2X+9Kc/2RGio+677z7z7W9/O27bnDlzzIIFC4wx2Z2fnv9cE8nFiRMnjCRz+PBhq8+OHTtMTk6Oef/9922L3Q5XK856OnTokJFk3nvvPWNMduUnESm37NLZ2ammpiYFg0FrW25uroLBoBoaGhyMzHmRSE
SSVFxcLElqamrS5cuX43I1duxY+f3+rMpVdXW17rvvvrg8SOTnL3/5iyZNmqSvf/3rKi0t1cSJE/XrX//a2n/q1CmFQqG4/Hg8HlVWVmZFfu644w7V19fr7bffliS9/vrrOnDggGbOnCmJ/HSXSC4aGhpUVFSkSZMmWX2CwaByc3PV2Nhoe8xOi0QiysnJUVFRkSTy01PKPdX2gw8+UFdXl7xeb9x2r9ert956y6GonBeLxVRTU6Np06Zp3LhxkqRQKKT8/Hzrj/sTXq9XoVDIgSjtt3nzZv3zn//U4cOHr9iX7fl59913tX79etXW1urJJ5/U4cOH9b3vfU/5+fmqqqqycnC1Yy0b8vPEE08oGo1q7NixGjRokLq6urR69WotWLBAkrI+P90lkotQKKTS0tK4/Xl5eSouLs66fF26dEnLly/X/PnzrSfbkp94KVd84Oqqq6t1/PhxHThwwOlQUkZra6uWLVum3bt3q6CgwOlwUk4sFtOkSZP0k5/8RJI0ceJEHT9+XM8//7yqqqocjs55f/7zn/XCCy9o06ZN+tznPqfm5mbV1NSorKyM/KDPLl++rG984xsyxmj9+vVOh5OyUm7ZZfjw4Ro0aNAVVySEw2H5fD6HonLW0qVLtX37du3du1cjR460tvt8PnV2dqqtrS2uf7bkqqmpSefOndMXv/hF5eXlKS8vT/v27dMvfvEL5eXlyev1ZnV+RowYodtvvz1u22233abTp09LkpWDbD3Wvv/97+uJJ57QvHnzNH78eH3rW9/SY489prq6Oknkp7tEcuHz+XTu3Lm4/R9//LHOnz+fNfn6pPB47733tHv3bmvWQyI/PaVc8ZGfn6+KigrV19db22KxmOrr6xUIBByMzH7GGC1dulRbt27Vnj17VF5eHre/oqJCgwcPjstVS0uLTp8+nRW5uvfee3Xs2DE1NzdbbdKkSVqwYIH1dTbnZ9q0aVdcmv3222/rlltukSSVl5fL5/PF5ScajaqxsTEr8vPhhx8qNzf+I3DQoEGKxWKSyE93ieQiEAiora1NTU1NVp89e/YoFoupsrLS9pjt9knhcfLkSf39739XSUlJ3P5sz88VnD7j9Wo2b95sXC6X2bhxozlx4oR55JFHTFFRkQmFQk6HZqslS5YYj8djXn31VXP27Fmrffjhh1afxYsXG7/fb/bs2WOOHDliAoGACQQCDkbtrO5XuxiT3fk5dOiQycvLM6tXrzYnT540L7zwgrnpppvMH//4R6vPmjVrTFFRkXn55ZfNv/71LzNr1qyMvZS0p6qqKvOpT33KutT2pZdeMsOHDzePP/641Seb8tPe3m6OHj1qjh49aiSZn//85+bo0aPW1RqJ5GLGjBlm4sSJprGx0Rw4cMCMGTMmYy4l7S0/nZ2d5oEHHjAjR440zc3NcZ/XHR0d1mtkcn6SlZLFhzHG/PKXvzR+v9/k5+ebKVOmmIMHDzodku0kXbVt2LDB6vPRRx+Z7373u2bYsGHmpptuMl/72tfM2bNnnQvaYT2Lj2zPzyuvvGLGjRtnXC6XGTt2rPnVr34Vtz8Wi5mVK1car9drXC6Xuffee01LS4tD0dorGo2aZcuWGb/fbwoKCsynP/1p89RTT8X9s8im/Ozdu/eqnzdVVVXGmMRy8d///tfMnz/fFBYWGrfbbR5++GHT3t7uwG/T/3rLz6lTp675eb13717rNTI5P8nKMabb7fwAAAAGWMqd8wEAADIbxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALDV/wHgUGBrLgpAzQAAAABJRU5ErkJggg==",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 0.0006, norm: 0.0004, param_norm: 938.1175, vb_loss: 0.0003, ce_loss: 0.2727: 58%|█████▊ | 900/1563 [01:05<00:27, 24.27it/s]"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVn0lEQVR4nO3df2yV5f3/8VdL29Nqe05pWc+ho0e6SYYOcKzQcsRsbp4FmFEYbBPCRlUyAyuO2mRidbDEjZVsyUQXxWzLYMtkuCaCk00JKwhjqS3tqLMiFWMjjXAOOtZzCkqLPdf3j292Phwopaec3ufX85FcCb3vq/d59331nL657vu+7gxjjBEAAIBFMuMdAAAASC8UHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFJjVnw8/fTTmjx5snJzc1VVVaXW1taxeikAAJBEMsbi2S7PP/+8VqxYoWeffVZVVVXavHmzGhsb1dXVpZKSkmG/NxQK6eTJkyooKFBGRkasQwMAAGPAGKO+vj6VlpYqM/MqcxtmDFRWVpqamprw14ODg6a0tNQ0NDRc9Xt7enqMJBqNRqPRaEnYenp6rvq3PuanXQYGBtTe3i6v1xvelpmZKa/Xq+bm5sv69/f3KxgMhpvhIbsAACStgoKCq/aJefHx4YcfanBwUE6nM2K70+mUz+e7rH9DQ4McDke4ud3uWIcEAAAsMpJLJuJ+t0t9fb0CgUC49fT0xDskAAAwhrJifcAJEyZo3Lhx8vv9Edv9fr9cLtdl/W02m2w2W6zDAAAACSrmMx85OTmqqKhQU1NTeFsoFFJTU5M8Hk+sXw4AACSZmM98SFJdXZ2qq6s1a9YsVVZWavPmzTp37pzuu+++sXg5AACQRMak+Ljnnnv0wQcfaMOGDfL5fPrCF76gV1555bKLUAEAQPoZk0XGrkUwGJTD4Yh3GAAAYBQCgYDsdvuwfeJ+twsAAEgvFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSWfEOANfmv//977D7x48fP+K+V/o+JIbhxo9xTh0jHeer9R3u+xB/YzHOQ31vomLmAwAAWIriAwAAWCrq4uPgwYO66667VFpaqoyMDO3atStivzFGGzZs0MSJE5WXlyev16vjx4/HKl4AAJDkor7m49y5c7rlllt0//33a/HixZft//nPf66nnnpKv//971VeXq7169dr3rx5Onr0qHJzc2MSdCqI5hxeMrwO/k88cp6K54QTXaKPM2LH6rynwzhHXXwsWLBACxYsGHKfMUabN2/Wj370Iy1cuFCS9Ic//EFOp1O7du3S0qVLry1aAACQ9GJ6zUd3d7d8Pp+8Xm94m8PhUFVVlZqbm4f8nv7+fgWDwYgGAABSV0yLD5/PJ0lyOp0R251OZ3jfpRoaGuRwOMKtrKwsliEBAIAEE/d1Purr61VXVxf+OhgMJlUBkmjn5jifb414jzvjbA3GOT3Ee5yl9BvrmM58uFwuSZLf74/Y7vf7w/suZbPZZLfbIxoAAEhdMS0+ysvL5XK51NTUFN4WDAbV0tIij8cTy5cCAABJKurTLmfPntU777wT/rq7u1sdHR0qKiqS2+1WbW2tfvrTn2rKlCnhW21LS0u1aNGiWMadMK5lGdxojjsWr4H4Y5zTA+OcHhjnkYu6+Ghra9NXvvKV8Nf/u16jurpa27Zt08MPP6xz587pgQceUG9vr2677Ta98sorrPEBAAAkjaL4uP3222WMueL+jIwMPf7443r88cevKTAAAJCaeLYLAACwVIYZbhojDoLBoBwOR7zDiAmWvIYU+XvAOKcuxj
k9XPq5zlhfLhAIXPXOVWY+AACApSg+AACApSg+AACApeK+vHoqG6s1QJBcOCecHhjn9MA4xwYzHwAAwFIUHwAAwFKcdrHQxdN1nIIBAKQrZj4AAIClKD4AAIClKD4AAICluOYjTq52Gy63c6UmxhkAmPkAAAAWo/gAAACW4rQLAMQYp9fSx0ifZszvRCRmPgAAgKUoPgAAgKUoPgAAgKUyjDEm3kFcLBgMyuFwxDuMuBvt8uvpfh4x2UQzzoxtahrud4AxT27p+v4OBAKy2+3D9mHmAwAAWIriAwAAWIriAwAAWIp1PhJUNPeLI3ldPM5XG9eRrieA5HK1Ry0geUXz/k43zHwAAABLRVV8NDQ0aPbs2SooKFBJSYkWLVqkrq6uiD7nz59XTU2NiouLlZ+fryVLlsjv98c0aAAAkLyiutV2/vz5Wrp0qWbPnq1PPvlEjz76qDo7O3X06FFdf/31kqTVq1frr3/9q7Zt2yaHw6E1a9YoMzNT//znP0f0Gtxqe21Ywjd1cUtm+uH9nLpS+f08klttr2mdjw8++EAlJSU6cOCAvvSlLykQCOhTn/qUtm/frm9+85uSpGPHjummm25Sc3Oz5syZc9VjUnxcGz6sUlcqf1hhaLyfU1cqv5/HfJ2PQCAgSSoqKpIktbe368KFC/J6veE+U6dOldvtVnNz85DH6O/vVzAYjGgAACB1jbr4CIVCqq2t1dy5czVt2jRJks/nU05OjgoLCyP6Op1O+Xy+IY/T0NAgh8MRbmVlZaMNCQAAJIFR32pbU1Ojzs5OHTp06JoCqK+vV11dXfjrYDBIARIFbs+ExDinkuHe05yGSQ/pMM6jKj7WrFmj3bt36+DBg5o0aVJ4u8vl0sDAgHp7eyNmP/x+v1wu15DHstlsstlsowkDAAAkoahOuxhjtGbNGu3cuVP79u1TeXl5xP6KigplZ2erqakpvK2rq0snTpyQx+OJTcQAACCpRTXzUVNTo+3bt+vFF19UQUFB+DoOh8OhvLw8ORwOrVy5UnV1dSoqKpLdbteDDz4oj8czojtdEHupOF0HpCvez0gVURUfW7ZskSTdfvvtEdu3bt2qe++9V5L0xBNPKDMzU0uWLFF/f7/mzZunZ555JibBAgCA5HdN63yMBdb5iM7VLjjlf0qpY6TPhmDMk1sqr/+A/xPNs16SbdzHfJ0PAACAaPFU2yTA0xDTw7WMc7L9zyjd8Z5OD6Md53R4PzPzAQAALEXxAQAALEXxAQAALMU1HxYabglszg2mjuGWRua6jtTBOKcPPrtjj5kPAABgKYoPAABgKU67xEkqLzCDSCwOlh54T6cHxjk2mPkAAACWovgAAACWovgAAACW4pqPGIvVssmcK0xsjHN6YJzTA+NsPWY+AACApSg+AACApTjtMoThVi4cav9oMD2XGMZi5cJLMdaJjXFOHWO16uyVjonRY+YDAABYiuIDAABYiuIDAABYims+RmC4c4Wc/0sdjHPqYkns9MR7OnEx8wEAACxF8QEAACxF8QEAACyVYYwx8Q7iYsFgUA6HI95hAACAUQgEArLb7cP2YeYDAABYKqriY8uWLZoxY4bsdrvsdrs8Ho9efvnl8P7z58+rpqZGxcXFys/P15IlS+T3+2MeNAAASF5RFR+TJk3Spk2b1N7erra2Nn31q1/VwoUL9eabb0qSHnroIb300ktqbGzUgQMHdPLkSS1evHhMAgcAAEnKXKPx48eb3/72t6a3t9dkZ2ebxsbG8L633nrLSDLNzc0jPl4gEDCSaDQajUajJWELBAJX/Vs/6ms+BgcHtWPHDp07d04ej0ft7e26cOGCvF5vuM/UqVPldrvV3Nx8xeP09/crGAxGNAAAkLqiLj7eeOMN5efny2azadWqVdq5c6duvvlm+Xw+5eTkqLCwMKK/0+mUz+e74vEaGhrkcDjCraysLOofAgAAJI+oi4/Pfe5z6ujoUEtLi1avXq3q6modPXp01AHU19crEA
iEW09Pz6iPBQAAEl/Uz3bJycnRjTfeKEmqqKjQ4cOH9eSTT+qee+7RwMCAent7I2Y//H6/XC7XFY9ns9lks9mijxwAACSla17nIxQKqb+/XxUVFcrOzlZTU1N4X1dXl06cOCGPx3OtLwMAAFJEVDMf9fX1WrBggdxut/r6+rR9+3a9+uqr2rNnjxwOh1auXKm6ujoVFRXJbrfrwQcflMfj0Zw5c8YqfgAAkGSiKj5Onz6tFStW6NSpU3I4HJoxY4b27Nmjr33ta5KkJ554QpmZmVqyZIn6+/s1b948PfPMM2MSOAAASE482wUAAMQMz3YBAAAJh+IDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYKuGKjwR71AwAAIjCSP6OJ1zx0dfXF+8QAADAKI3k73jCPdU2FArp5MmTMsbI7Xarp6fnqk/HS0fBYFBlZWXk5wrIz/DIz/DIz/DIz5Wlc26MMerr61NpaakyM4ef28iyKKYRy8zM1KRJkxQMBiVJdrs97QYwGuRneORneORneORneOTnytI1Nw6HY0T9Eu60CwAASG0UHwAAwFIJW3zYbDb9+Mc/ls1mi3coCYn8DI/8DI/8DI/8DI/8XBm5GZmEu+AUAACktoSd+QAAAKmJ4gMAAFiK4gMAAFiK4gMAAFiK4gMAAFgqYYuPp59+WpMnT1Zubq6qqqrU2toa75As19DQoNmzZ6ugoEAlJSVatGiRurq6IvqcP39eNTU1Ki4uVn5+vpYsWSK/3x+niONr06ZNysjIUG1tbXhbuufn/fff13e+8x0VFxcrLy9P06dPV1tbW3i/MUYbNmzQxIkTlZeXJ6/Xq+PHj8cxYusMDg5q/fr1Ki8vV15enj772c/qJz/5ScRDsdIpPwcPHtRdd92l0tJSZWRkaNeuXRH7R5KLM2fOaPny5bLb7SosLNTKlSt19uxZC3+KsTNcfi5cuKB169Zp+vTpuv7661VaWqoVK1bo5MmTEcdI5fxEzSSgHTt2mJycHPO73/3OvPnmm+Z73/ueKSwsNH6/P96hWWrevHlm69atprOz03R0dJivf/3rxu12m7Nnz4b7rFq1ypSVlZmmpibT1tZm5syZY2699dY4Rh0fra2tZvLkyWbGjBlm7dq14e3pnJ8zZ86YG264wdx7772mpaXFvPvuu2bPnj3mnXfeCffZtGmTcTgcZteuXeb11183d999tykvLzcff/xxHCO3xsaNG01xcbHZvXu36e7uNo2NjSY/P988+eST4T7plJ+//e1v5rHHHjMvvPCCkWR27twZsX8kuZg/f7655ZZbzGuvvWb+8Y9/mBtvvNEsW7bM4p9kbAyXn97eXuP1es3zzz9vjh07Zpqbm01lZaWpqKiIOEYq5ydaCVl8VFZWmpqamvDXg4ODprS01DQ0NMQxqvg7ffq0kWQOHDhgjPn/v/DZ2dmmsbEx3Oett94ykkxzc3O8wrRcX1+fmTJlitm7d6/58pe/HC4+0j0/69atM7fddtsV94dCIeNyucwvfvGL8Lbe3l5js9nMn/70JytCjKs777zT3H///RHbFi9ebJYvX26MSe/8XPrHdSS5OHr0qJFkDh8+HO7z8ssvm4yMDPP+++9bFrsVhirOLtXa2mokmffee88Yk175GYmEO+0yMDCg9vZ2eb3e8LbMzEx5vV41NzfHMbL4CwQCkqSioiJJUnt7uy5cuBCRq6lTp8rtdqdVrmpqanTnnXdG5EEiP3/5y180a9Ysfetb31JJSYlmzpyp3/zmN+H93d3d8vl8EflxOByqqqpKi/zceuutampq0ttvvy1Jev3113Xo0CEtWLBAEvm52Ehy0dzcrMLCQs2aNSvcx+v1KjMzUy0tLZbHHG+BQEAZGRkqLCyURH4ulXBPtf3www81ODgop9MZsd3pdOrYsWNxiir+QqGQamtrNXfuXE2bNk2S5PP5lJOTE/
7l/h+n0ymfzxeHKK23Y8cO/etf/9Lhw4cv25fu+Xn33Xe1ZcsW1dXV6dFHH9Xhw4f1gx/8QDk5Oaqurg7nYKj3Wjrk55FHHlEwGNTUqVM1btw4DQ4OauPGjVq+fLkkpX1+LjaSXPh8PpWUlETsz8rKUlFRUdrl6/z581q3bp2WLVsWfrIt+YmUcMUHhlZTU6POzk4dOnQo3qEkjJ6eHq1du1Z79+5Vbm5uvMNJOKFQSLNmzdLPfvYzSdLMmTPV2dmpZ599VtXV1XGOLv7+/Oc/67nnntP27dv1+c9/Xh0dHaqtrVVpaSn5wahduHBB3/72t2WM0ZYtW+IdTsJKuNMuEyZM0Lhx4y67I8Hv98vlcsUpqvhas2aNdu/erf3792vSpEnh7S6XSwMDA+rt7Y3ony65am9v1+nTp/XFL35RWVlZysrK0oEDB/TUU08pKytLTqczrfMzceJE3XzzzRHbbrrpJp04cUKSwjlI1/faD3/4Qz3yyCNaunSppk+fru9+97t66KGH1NDQIIn8XGwkuXC5XDp9+nTE/k8++URnzpxJm3z9r/B47733tHfv3vCsh0R+LpVwxUdOTo4qKirU1NQU3hYKhdTU1CSPxxPHyKxnjNGaNWu0c+dO7du3T+Xl5RH7KyoqlJ2dHZGrrq4unThxIi1ydccdd+iNN95QR0dHuM2aNUvLly8P/zud8zN37tzLbs1+++23dcMNN0iSysvL5XK5IvITDAbV0tKSFvn56KOPlJkZ+RE4btw4hUIhSeTnYiPJhcfjUW9vr9rb28N99u3bp1AopKqqKstjttr/Co/jx4/r73//u4qLiyP2p3t+LhPvK16HsmPHDmOz2cy2bdvM0aNHzQMPPGAKCwuNz+eLd2iWWr16tXE4HObVV181p06dCrePPvoo3GfVqlXG7Xabffv2mba2NuPxeIzH44lj1PF18d0uxqR3flpbW01WVpbZuHGjOX78uHnuuefMddddZ/74xz+G+2zatMkUFhaaF1980fz73/82CxcuTNlbSS9VXV1tPv3pT4dvtX3hhRfMhAkTzMMPPxzuk0756evrM0eOHDFHjhwxkswvf/lLc+TIkfDdGiPJxfz5883MmTNNS0uLOXTokJkyZUrK3Eo6XH4GBgbM3XffbSZNmmQ6OjoiPq/7+/vDx0jl/EQrIYsPY4z51a9+Zdxut8nJyTGVlZXmtddei3dIlpM0ZNu6dWu4z8cff2y+//3vm/Hjx5vrrrvOfOMb3zCnTp2KX9Bxdmnxke75eemll8y0adOMzWYzU6dONb/+9a8j9odCIbN+/XrjdDqNzWYzd9xxh+nq6opTtNYKBoNm7dq1xu12m9zcXPOZz3zGPPbYYxF/LNIpP/v37x/y86a6utoYM7Jc/Oc//zHLli0z+fn5xm63m/vuu8/09fXF4aeJveHy093dfcXP6/3794ePkcr5iVaGMRct5wcAADDGEu6aDwAAkNooPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKX+H9ZtKdLUYoywAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 0.0004, norm: 0.0002, param_norm: 938.9938, vb_loss: 0.0002, ce_loss: 0.1951: 77%|███████▋ | 1200/1563 [01:26<00:14, 25.02it/s]"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYT0lEQVR4nO3df0xV9/3H8ReIXGyRi2AAmVLZ2sx2/qhDRWoz+4NFXdP6K2s1bqI1a3TYSUlWazud6eJwW1Kxi7XZL+iyOp1JwdWtGoe/5oaoTNpaW2pTpqR6cZ2Fi7aChc/3j317cs9Vr1y8nHsv9/lITnJ+ce6H95tzeeecz+ecOGOMEQAAgEPiw90AAAAQWyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAo/qs+Ni0aZNGjhyppKQk5efn68iRI331UQAAIIrE9cW7XbZt26aFCxfq5ZdfVn5+vsrLy7V9+3Y1NjYqIyMj4M92d3fr7NmzGjx4sOLi4kLdNAAA0AeMMWpvb1d2drbi429wbcP0gUmTJpni4mJruaury2RnZ5uysrIb/mxzc7ORxMTExMTExBSFU3Nz8w3/14f8tktnZ6fq6+tVWFhorYuPj1dhYaFqa2uv2r+jo0Ner9eaDC/ZBQAgag0ePPiG+4S8+Pj444/V1dWlzMxM2/rMzEx5PJ6r9i8rK5Pb7bamnJycUDcJAAA4pCddJsI+2mXVqlVqa2uzpubm5nA3CQAA9KGEUB9w6NChGjBggFpaWmzrW1palJWVddX+LpdLLpcr1M0AAAARKuRXPhITE5WXl6eamhprXXd3t2pqalRQUBDqjwMAAFEm5Fc+JKm0tFRFRUWaMGGCJk2apPLycl26dEmLFy/ui48DAABRpE+Kj8cee0z/+c9/tGbNGnk8Ht19993atWvXVZ1QAQBA7OmTh4zdDK/XK7fbHe5mAACAXmhra1NKSkrAfcI+2gUAAMQWig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOCohHA3AKFVUVFhW66srLTmFy1aZNs2a9as6x5n7dq1tuWNGzfeZMsQSuQ5NgTKs2TPNXmObr657m2eJXuuIznPXPkAAACOovgAAACOCrr4OHjwoB5++GFlZ2crLi5O1dXVtu3GGK1Zs0bDhg3ToEGDVFhYqFOnToWqvQAAIMoF3efj0qVLGjdunB5//HHNmTPnqu0///nP9eKLL+qVV15Rbm6uVq9erWnTpunkyZNKSkoKSaOjlf/9W183uo/XU/v377ct9/ReYUNDQ8Bl9Jx/nu+77z5rPjU1NSSf0ds8S/bckufeC/f5fKPPIc+9Fyi3/voi18Hk2V+05Dro4mPGjBmaMWPGNbcZY1ReXq4f/ehHmjlzpiTp97//vTIzM1VdXa158+bdXGsBAEDUC2mfj6amJnk8HhUWFlrr3G638vPzVVtbe82f6ejokNfrtU0AAKD/Cmnx4fF4JEmZmZm29ZmZmdY2f2VlZXK73dY0YsSIUDYJAABEmDhjjOn1D8fFqaqqyrof9c9//lNTpkzR2bNnNWzYMGu/Rx99VHFxcdq2bdtVx+jo6FBHR4e17PV6o6oAaWpqsi2H6p5+b/nfKwxkx44d1vxtt91m23b69OlQNalf6C95luy5Js/SuHHjbMu+nejJc//le06HO89Sz3MdKM9SZOS6ra1NKSkpAfcJ6ZWPrKwsSVJLS4ttfUtLi7XNn8vlUkpKim0CAAD9V0iLj9zcXGVlZammpsZa5/V6VVdXp4KCglB+FAAAiFJB33a5ePGiPvjgA0nS+PHj9cILL+j+++9XWlqacn
Jy9LOf/Uzr16+3DbV96623ejzU1uv1yu129+63CQP/S16hGubke5zW1lbbNt+hm/6GDBkSks+H3Rejt77g//jj3iLP4Td16lTb8r///W9rvi/OZ8mea/Ic+T755JMe7Xf33XfblsvLy635QHmW+leue3LbJeihtseOHdP9999vLZeWlkqSioqKVFlZqaefflqXLl3SE088odbWVt17773atWtXzD/jAwAA/E/Qxcd9992nQBdL4uLi9Pzzz+v555+/qYYBAID+iXe7AAAARwV95QN2/sOafPsC+A/fCubx5r63tvz59jMpKSm5URMRAv7D2/zvz/rmJFA/gd7mWSLXfeXAgQPX3eZ/zvb2dQXkObr5nu/+/T98c+3//2D27NnWPHm248oHAABwFMUHAABwFMUHAABw1E09Xr0vRNtzPoJRVVVlzfuP+fZ/tK5/HwNED988S/Zck+f+gzwD1+b449UBAABuhOIDAAA4itsuYeI/XGv//v22Zd8hWug//C/Vk+f+w/ec5nxGLOO2CwAAiDgUHwAAwFEUHwAAwFE8Xj1M/O8J+w+99X3Nd6DHPyPy+fbz8L/3T577D99zOtD5LJHraBao3xZ57jmufAAAAEdRfAAAAEcx1DZM/N9w6P8WTN9LuK2trdfd9sorr4S4ZQi1QG+8DZRnf4sXLw5hqxBqPc2zZM/12rVrbdv834yKyFJRUWFb9n27MXn+H4baAgCAiEPxAQAAHEXxAQAAHEWfjwgxc+ZM23JlZWWPfq6kpMS2XF1dbVtua2u7iVYhnAI9gp/HdUcX/1wOGTLkutvIc/QKJs++2/ob+nwAAICIQ/EBAAAcRfEBAAAcRZ+PCOX7zIDy8nLbNv9HN/vyHXMu8Xjf/sT3nrH//WLfc4Z+PkB4+D+/qbfP8gjVccKFPh8AACDiBFV8lJWVaeLEiRo8eLAyMjI0a9YsNTY22va5fPmyiouLlZ6eruTkZM2dO1ctLS0hbTQAAIheQd12mT59uubNm6eJEyfq888/17PPPqsTJ07o5MmTuvXWWyVJy5Yt01/+8hdVVlbK7XZr+fLlio+P1z/+8Y8efQa3Xa7mHw/fy+r+w7f8MWyvf4qlYXv9EbfJ0J/15LZLQjAH3LVrl225srJSGRkZqq+v1ze+8Q21tbXpt7/9rbZs2aIHHnhA0v+eg3/nnXfq8OHDmjx5cpC/AgAA6G9uqs/HFxV7WlqaJKm+vl5XrlxRYWGhtc+oUaOUk5Oj2traax6jo6NDXq/XNgEAgP6r18VHd3e3SkpKNGXKFI0ePVqS5PF4lJiYqNTUVNu+mZmZ8ng81zxOWVmZ3G63NY0YMaK3TQIAAFEgqNsuvoqLi3XixAkdOnTophqwatUqlZaWWster5cCxM/N3BP2HZY7depU2zaG4UaXQP17fLcx3DryBTqnb9SPyxd9faJXrD+KvVfFx/Lly7Vz504dPHhQw4cPt9ZnZWWps7NTra2ttqsfLS0tysrKuuaxXC6XXC5Xb5oBAACiUFC3XYwxWr58uaqqqrR3717l5ubatufl5WngwIGqqamx1jU2NurMmTMqKCgITYsBAEBUC+rKR3FxsbZs2aIdO3Zo8ODBVj8Ot9utQYMGye12a8mSJSotLVVaWppSUlL05JNPqqCggJEuIVRVVdXjfRctWmTNc/k9ugST54aGBmuePEeX3p7PiC43yrPvrZZYyHNQxcfmzZslXf1474qKCitYGzZsUHx8vObOnauOjg5NmzZNL730UkgaCwAAol9QxUdPnkeWlJSkTZs2adOmTb1uFAAA6L94twsAAHBUr4fa4ub4v7Vw5MiRtuWSkhJrPtBbbG/E997hjh07en0chIb/cOdQ5bm1tbXXP4vQC5Rnqfe55hyOPL65rqystG3zf+ZVT/n3+eiPeefKBwAAcBTFBwAAcBTFBwAAcBR9PhzU1NRkzfvf06uuru7VMX2f7yBJa9eutS3zzAfn+eZZsue6t3n2xyPUw6
+v8uzbbyBUfy+4OU5/d8fC+cyVDwAA4CiKDwAA4Kg405MnhznI6/XK7XaHuxm95j/EzleoLs+Vl5db8/1xCFY0cDrPErkOh77Is2TPNXkOv0B5lkJzTsdSntva2pSSkhJwH658AAAAR1F8AAAAR1F8AAAARzHUthcqKiqsef/H6dKvo/8gz7HBN89SaIa60n8nMmzYsMGa93/UeaiGNMdqv46bxZUPAADgKIoPAADgKIba/r+qqipr3v8Nof6X627m7aNf2L9/v2159uzZN31M3Jj/Uyl989AXefb/DPLsjHHjxtmWfd8qS577D9/vbSnwd3df5Fki19fCUFsAABBxKD4AAICjKD4AAICj6PNxDZ988kmP9/V/u6jvvWV/3BuMPD3NdTB59n/rZVtbW5CtQqj19pzmfI4ufZFnyX5Ocz7fGH0+AABAxKH4AAAAjqL4AAAAjqLPBwAACBn6fAAAgIgTVPGxefNmjR07VikpKUpJSVFBQYHeeOMNa/vly5dVXFys9PR0JScna+7cuWppaQl5owEAQPQKqvgYPny41q9fr/r6eh07dkwPPPCAZs6cqXfeeUeS9NRTT+n111/X9u3bdeDAAZ09e1Zz5szpk4YDAIAoZW7SkCFDzG9+8xvT2tpqBg4caLZv325te/fdd40kU1tb2+PjtbW1GUlMTExMTExMUTi1tbXd8H99r/t8dHV1aevWrbp06ZIKCgpUX1+vK1euqLCw0Npn1KhRysnJUW1t7XWP09HRIa/Xa5sAAED/FXTx8fbbbys5OVkul0tLly5VVVWV7rrrLnk8HiUmJl71xsjMzEx5PJ7rHq+srExut9uaRowYEfQvAQAAokfQxcdXv/pVNTQ0qK6uTsuWLVNRUZFOnjzZ6wasWrVKbW1t1tTc3NzrYwEAgMiXEOwPJCYm6vbbb5ck5eXl6ejRo9q4caMee+wxdXZ2qrW11Xb1o6WlRVlZWdc9nsvlksvlCr7lAAAgKt30cz66u7vV0dGhvLw8DRw4UDU1Nda2xsZGnTlzRgUFBTf7MQAAoJ8I6srHqlWrNGPGDOXk5Ki9vV1btmzR/v37tXv3brndbi1ZskSlpaVKS0tTSkqKnnzySRUUFGjy5Ml91X4AABBlgio+zp8/r4ULF+rcuXNyu90aO3asdu/erW9+85uSpA0bNig+Pl5z585VR0eHpk2bppdeeqlPGg4AAKIT73YBAAAhw7tdAABAxKH4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjoq44iPCXjUDAACC0JP/4xFXfLS3t4e7CQAAoJd68n884t5q293drbNnz8oYo5ycHDU3N9/w7XixyOv1asSIEcTnOohPYMQnMOITGPG5vliOjTFG7e3tys7OVnx84GsbCQ61qcfi4+M1fPhweb1eSVJKSkrMJTAYxCcw4hMY8QmM+ARGfK4vVmPjdrt7tF/E3XYBAAD9G8UHAABwVMQWHy6XSz/+8Y/lcrnC3ZSIRHwCIz6BEZ/AiE9gxOf6iE3PRFyHUwAA0L9F7JUPAADQP1F8AAAAR1F8AAAAR1F8AAAAR1F8AAAAR0Vs8bFp0yaNHDlSSUlJys/P15EjR8LdJMeVlZVp4sSJGjx4sDIyMjRr1iw1Njba9rl8+bKKi4uVnp6u5ORkzZ07Vy0tLWFqcXitX79ecXFxKikpsdbFenw++ugjfec731F6eroGDRqkMWPG6NixY9Z2Y4zWrFmjYcOGadCgQSosLNSpU6fC2GLndHV1afXq1crNzdWgQYP0la98RT/5yU9sL8WKpfgcPHhQDz/8sLKzsxUXF6fq6mrb9p7E4sKFC1qwYIFSUlKUmpqqJUuW6OLFiw7+Fn0nUHyuXLmilStXasyYMbr11luVnZ2thQsX6uzZs7Zj9Of4BM1EoK1bt5rExETzu9/9zrzzzjvme9/7nklNTTUtLS3hbpqjpk
2bZioqKsyJEydMQ0OD+da3vmVycnLMxYsXrX2WLl1qRowYYWpqasyxY8fM5MmTzT333BPGVofHkSNHzMiRI83YsWPNihUrrPWxHJ8LFy6Y2267zSxatMjU1dWZDz/80Ozevdt88MEH1j7r1683brfbVFdXmzfffNM88sgjJjc313z22WdhbLkz1q1bZ9LT083OnTtNU1OT2b59u0lOTjYbN2609oml+Pz1r381zz33nHnttdeMJFNVVWXb3pNYTJ8+3YwbN84cPnzY/P3vfze33367mT9/vsO/Sd8IFJ/W1lZTWFhotm3bZt577z1TW1trJk2aZPLy8mzH6M/xCVZEFh+TJk0yxcXF1nJXV5fJzs42ZWVlYWxV+J0/f95IMgcOHDDG/O8PfuDAgWb79u3WPu+++66RZGpra8PVTMe1t7ebO+64w+zZs8dMnTrVKj5iPT4rV640995773W3d3d3m6ysLPOLX/zCWtfa2mpcLpf54x//6EQTw+qhhx4yjz/+uG3dnDlzzIIFC4wxsR0f/3+uPYnFyZMnjSRz9OhRa5833njDxMXFmY8++sixtjvhWsWZvyNHjhhJ5vTp08aY2IpPT0TcbZfOzk7V19ersLDQWhcfH6/CwkLV1taGsWXh19bWJklKS0uTJNXX1+vKlSu2WI0aNUo5OTkxFavi4mI99NBDtjhIxOfPf/6zJkyYoG9/+9vKyMjQ+PHj9etf/9ra3tTUJI/HY4uP2+1Wfn5+TMTnnnvuUU1Njd5//31J0ptvvqlDhw5pxowZkoiPr57Eora2VqmpqZowYYK1T2FhoeLj41VXV+d4m8Otra1NcXFxSk1NlUR8/EXcW20//vhjdXV1KTMz07Y+MzNT7733XphaFX7d3d0qKSnRlClTNHr0aEmSx+NRYmKi9cf9hczMTHk8njC00nlbt27Vv/71Lx09evSqbbEenw8//FCbN29WaWmpnn32WR09elQ/+MEPlJiYqKKiIisG1zrXYiE+zzzzjLxer0aNGqUBAwaoq6tL69at04IFCyQp5uPjqyex8Hg8ysjIsG1PSEhQWlpazMXr8uXLWrlypebPn2+92Zb42EVc8YFrKy4u1okTJ3To0KFwNyViNDc3a8WKFdqzZ4+SkpLC3ZyI093drQkTJuinP/2pJGn8+PE6ceKEXn75ZRUVFYW5deH3pz/9Sa+++qq2bNmir33ta2poaFBJSYmys7OJD3rtypUrevTRR2WM0ebNm8PdnIgVcbddhg4dqgEDBlw1IqGlpUVZWVlhalV4LV++XDt37tS+ffs0fPhwa31WVpY6OzvV2tpq2z9WYlVfX6/z58/r61//uhISEpSQkKADBw7oxRdfVEJCgjIzM2M6PsOGDdNdd91lW3fnnXfqzJkzkmTFIFbPtR/+8Id65plnNG/ePI0ZM0bf/e539dRTT6msrEwS8fHVk1hkZWXp/Pnztu2ff/65Lly4EDPx+qLwOH36tPbs2WNd9ZCIj7+IKz4SExOVl5enmpoaa113d7dqampUUFAQxpY5zxij5cuXq6qqSnv37lVubq5te15engYOHGiLVWNjo86cORMTsXrwwQf19ttvq6GhwZomTJigBQsWWPOxHJ8pU6ZcNTT7/fff12233SZJys3NVVZWli0+Xq9XdXV1MRGfTz/9VPHx9q/AAQMGqLu7WxLx8dWTWBQUFKi1tVX19fXWPnv37lV3d7fy8/Mdb7PTvig8Tp06pb/97W9KT0+3bY/1+Fwl3D1er2Xr1q3G5XKZyspKc/LkSfPEE0+Y1NRU4/F4wt00Ry1btsy43W6zf/9+c+7cOWv69NNPrX2WLl1qcnJyzN69e82xY8dMQUGBKSgoCGOrw8t3tIsxsR2fI0eOmISEBLNu3Tpz6tQp8+qrr5pbbrnF/OEPf7D2Wb9+vUlNTTU7duwwb731lpk5c2a/HUrqr6ioyHzpS1+yhtq+9tprZujQoebpp5+29oml+LS3t5vjx4+b48
ePG0nmhRdeMMePH7dGa/QkFtOnTzfjx483dXV15tChQ+aOO+7oN0NJA8Wns7PTPPLII2b48OGmoaHB9n3d0dFhHaM/xydYEVl8GGPML3/5S5OTk2MSExPNpEmTzOHDh8PdJMdJuuZUUVFh7fPZZ5+Z73//+2bIkCHmlltuMbNnzzbnzp0LX6PDzL/4iPX4vP7662b06NHG5XKZUaNGmV/96le27d3d3Wb16tUmMzPTuFwu8+CDD5rGxsYwtdZZXq/XrFixwuTk5JikpCTz5S9/2Tz33HO2fxaxFJ99+/Zd8/umqKjIGNOzWPz3v/818+fPN8nJySYlJcUsXrzYtLe3h+G3Cb1A8Wlqarru9/W+ffusY/Tn+AQrzhifx/kBAAD0sYjr8wEAAPo3ig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOCo/wNZvNGki+7vXAAAAABJRU5ErkJggg==",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"loss: 0.0005, norm: 0.0005, param_norm: 939.0556, vb_loss: 0.0001, ce_loss: 0.1722: 81%|████████ | 1260/1563 [01:37<00:23, 12.98it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m norm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mclip_grad_norm_(d3pm\u001b[38;5;241m.\u001b[39mx0_model\u001b[38;5;241m.\u001b[39mparameters(), \u001b[38;5;241m0.01\u001b[39m)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 31\u001b[0m param_norm \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m([\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m d3pm\u001b[38;5;241m.\u001b[39mx0_model\u001b[38;5;241m.\u001b[39mparameters()])\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m loss_ema \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 33\u001b[0m loss_ema \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n",
"File \u001b[0;32m~/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/functional.py:1610\u001b[0m, in \u001b[0;36mnorm\u001b[0;34m(input, p, dim, keepdim, out, dtype)\u001b[0m\n\u001b[1;32m 1608\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m p \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfro\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (dim \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(dim, (\u001b[38;5;28mint\u001b[39m, torch\u001b[38;5;241m.\u001b[39mSymInt)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(dim) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m 1609\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m out \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1610\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinalg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvector_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1611\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1612\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mvector_norm(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m2\u001b[39m, _dim, keepdim, dtype\u001b[38;5;241m=\u001b[39mdtype, out\u001b[38;5;241m=\u001b[39mout)\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"d3pm.train()\n",
"\n",
"n_epoch = 10\n",
"device = 'cuda'\n",
"from torchvision.utils import make_grid\n",
"from matplotlib import pyplot as plt\n",
"global_step = 0\n",
"for i in range(n_epoch):\n",
" d3pm.train()\n",
" pbar = tqdm(dataloader)\n",
" loss_ema = None\n",
" for x, _ in pbar:\n",
" #print(x)\n",
" optim.zero_grad()\n",
" x = x.to(device)\n",
" # discritize x to 10 bins\n",
" x = (x * N).long().clamp(0, N - 1)\n",
" \n",
" loss, info = d3pm(x)\n",
" \n",
" # if loss.item() > 1000 or torch.isnan(loss):\n",
" # print(f\"loss is too high, skipping, {loss.item()}\")\n",
" # loss.backward()\n",
" # optim.zero_grad()\n",
" # continue\n",
" #print(loss.item())\n",
" loss.backward()\n",
" norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.01)\n",
"\n",
" with torch.no_grad():\n",
" param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])\n",
" if loss_ema is None:\n",
" loss_ema = loss.item()\n",
" else:\n",
" loss_ema = 0.99 * loss_ema + 0.01 * loss.item()\n",
" pbar.set_description(f\"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}\")\n",
" optim.step()\n",
" global_step += 1\n",
"\n",
" if global_step % 300 == 1:\n",
" d3pm.eval()\n",
"\n",
" with torch.no_grad():\n",
" x = d3pm.sample(torch.randint(0, N, (4, 1, 32, 32)).cuda())\n",
" x_as_image = make_grid(x.float() / N, nrow=4)\n",
" plt.figure()\n",
" plt.imshow(x_as_image.permute(1, 2, 0).cpu().numpy())\n",
" plt.show()\n",
"\n",
" d3pm.train()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeuElEQVR4nO3dfWzV5f3/8Ve56RGlPbXc9EYKFlCYQHFjUDsVUSqlZoSbLsGbbDCZBFaI0HlXoyJufsvwHoNo4gJxEXEsAsNEUIotcRYclQ5R1wDpBEZbJlvPKcUW1l6/P5adn5W782lPefeU5yO5Es7n8+513p9csS8/55xeJ8Y55wQAwEXWzboBAMCliQACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiR7WDXxXS0uLjh49qri4OMXExFi3AwDwyDmn+vp6paamqlu3c9/ndLoAOnr0qNLS0qzbAAC00+HDhzVgwIBznu+wl+BWrlypq6++WpdddpkyMzP1ySefhPVzcXFxHdUSAOAiutDv8w4JoLffflsFBQVasmSJPv30U40ePVo5OTk6duzYBX+Wl90AoGu44O9z1wHGjRvn8vPzQ4+bm5tdamqqKyoquuDPBgIBJ4nBYDAYUT4CgcB5f99H/A7o1KlTKi8vV3Z2duhYt27dlJ2drbKysjPqm5qaFAwGWw0AQNcX8QD6+uuv1dzcrKSkpFbHk5KSVFNTc0Z9UVGR/H5/aPABBAC4NJj/HVBhYaECgUBoHD582LolAMBFEPGPYfft21fdu3dXbW1tq+O1tbVKTk4+o97n88nn80W6DQBAJxfxO6DY2FiNGTNGxcXFoWMtLS0qLi5WVlZWpJ8OABClOuQPUQsKCjRr1iz98Ic/1Lhx4/Tiiy+qoaFBP//5zzvi6QAAUahDAmjmzJn65z//qSeeeEI1NTW6/vrrtWXLljM+mAAAuHTFOOecdRPfFgwG5ff7rdsAALRTIBBQfHz8Oc+bfwoOAHBpIoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAICJiAfQk08+qZiYmFZj+PDhkX4aAECU69ERk44YMULbtm37/0/So0OeBgAQxTokGXr06KHk5OSOmBoA0EV0yHtA+/fvV2pqqgYPHqx77rlHhw4dOmdtU1OTgsFgqwEA6PoiHkCZmZlas2aNtmzZolWrVqmqqko333yz6uvrz1pfVFQkv98fGmlpaZFuCQDQCcU451xHPkFdXZ0GDRqk559/XnPmzDnjfFNTk5qamkKPg8EgIQQAXUAgEFB8fPw5z3f4pwMSEhJ07bXX6sCBA2c97/P55PP5OroNAEAn0+F/B3TixAkdPHhQKSkpHf1UAIAoEvEAeuCBB1RaWqq///3v+vjjjzV9+nR1795dd911V6SfCgAQxSL+EtyRI0d011136fjx4+rXr59uuukm7dy5U/369Yv0UwEAoliHfwjBq2AwKL/fb90GAKCdLvQhBPaCAwCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJjr86xiAaNW9e/ewazvT9lELFizwVH/55ZeHXTts2DBPc+fn54dd++yzz3qa28sGx42NjZ7mXrZsmaf6pUuXeqrHf3EHBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATLAVDzrcwIEDw66NjY31NPePfv
SjsGtvuukmT3MnJCSEXZuXl+dp7mh15MgRT/UrVqwIu3b69Ome5q6vrw+79q9//aunuUtLSz3Vo224AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiRjnnLNu4tuCwaD8fr91GziP73//+57qi4uLw65l7S++lpaWsGvvvfdeT3M3NDR4bSdsR48eDbv23//+t6e5KysrvbaDswgEAoqPjz/nee6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCih3UDiD5fffWVp/rjx4+HXXup7AW3a9cuT/V1dXVh1956662e5j516lTYtb///e89zQ2cD3dAAAATngNox44dmjJlilJTUxUTE6ONGze2Ou+c0xNPPKGUlBT16tVL2dnZ2r9/f6T6BQB0EZ4DqKGhQaNHj9bKlSvPen758uVasWKFXn31Ve3atUtXXHGFcnJy1NjY2O5mAQBdh+f3gHJzc5Wbm3vWc845vfjii3rsscc0depUSdIbb7yhpKQkbdy4UXfeeWf7ugUAdBkRfQ+oqqpKNTU1ys7ODh3z+/3KzMxUWVnZWX+mqalJwWCw1QAAdH0RDaCamhpJUlJSUqvjSUlJoXPfVVRUJL/fHxppaWmRbAkA0EmZfwqusLBQgUAgNA4fPmzdEgDgIohoACUnJ0uSamtrWx2vra0Nnfsun8+n+Pj4VgMA0PVFNIDS09OVnJys4uLi0LFgMKhdu3YpKysrkk8FAIhynj8Fd+LECR04cCD0uKqqShUVFUpMTNTAgQO1aNEi/eY3v9E111yj9PR0Pf7440pNTdW0adMi2TcAIMp5DqDdu3e32uqjoKBAkjRr1iytWbNGDz30kBoaGjR37lzV1dXppptu0pYtW3TZZZdFrmuY+te//uWp/sEHHwy79sc//rGnuffs2RN27YoVKzzN7UVFRYWn+ttvv91TfUNDQ9i1I0aM8DT3/fff76keiBTPATRhwgQ55855PiYmRk899ZSeeuqpdjUGAOjazD8FBwC4NBFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMx7nz76hgIBoPy+/3WbcCI16/jqK+vD7v2tdde8zT3nDlzwq796U9/6mnutWvXeqoHolEgEDjvf9PcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABM9rBsAvi0YDHbY3IFAoMPm/sUvfuGpft26dZ7qW1paPNUD0YA7IACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYiHHOOesmvi0YDMrv91u3gS7oiiuu8FS/efPmsGtvueUWT3Pn5uZ6qn///fc91QOdQSAQUHx8/DnPcwcEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMsBUPcA5DhgwJu/bTTz/1NHddXZ2n+g8//DDs2t27d3uae+XKlWHXdrJfF+jk2IoHANApEUAAABOeA2jHjh2aMmWKUlNTFRMTo40bN7Y6P3v2bMXExLQakydPjlS/AIAuwnMANTQ0aPTo0ed93Xjy5Mmqrq4OjbfeeqtdTQIAup4eXn8gNzf3gt9l4vP5lJyc3OamAABdX4e8B1RSUqL+/ftr2LBhmj9/vo4fP37O2qamJgWDwVYDAND1RTyAJk+erDfeeEPFxcX67W9/q9LSUuXm5qq5ufms9UVFRfL7/aGRlpYW6ZYAAJ2Q55fgLuTOO+8M/XvUqFHKyMjQkCFDVFJSookTJ55RX1hYqIKCgtDjYDBICAHAJaDDP4Y9ePBg9e3bVwcOHDjreZ/Pp/j4+FYDAND1dXgAHTlyRMePH1dKSkpHPxUAIIp4fgnuxIkTre5mqqqqVFFRocTERC
UmJmrp0qXKy8tTcnKyDh48qIceekhDhw5VTk5ORBsHAEQ3z3vBlZSU6NZbbz3j+KxZs7Rq1SpNmzZNe/bsUV1dnVJTUzVp0iT9+te/VlJSUljzsxccotH06dM91a9evdpTfVxcnKd6Lx599NGwa9944w1Pc1dXV3ttB13IhfaC83wHNGHChPNuSLh161avUwIALkHsBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEx43guuo7EXHC4Fo0aN8lT/3HPPhV17tu/dipTXXnvNU/3TTz8ddu0//vEPr+2gk7vQXnDcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABNsxQNEgYSEhLBrp0yZ4mnu1atXh10bExPjae7t27eHXXv77bd7mhudH1vxAAA6JQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYC844BLX1NQUdm2PHj08zf2f//wn7NqcnBxPc5eUlHiqx8XHXnAAgE6JAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY8LavBoCIyMjI8FT/k5/8JOzasWPHeprb6/Y6XnzxxRdh1+7YsaPD+kDnxB0QAMCEpwAqKirS2LFjFRcXp/79+2vatGmqrKxsVdPY2Kj8/Hz16dNHvXv3Vl5enmprayPaNAAg+nkKoNLSUuXn52vnzp364IMPdPr0aU2aNEkNDQ2hmsWLF2vz5s1av369SktLdfToUc2YMSPijQMAopunF3+3bNnS6vGaNWvUv39/lZeXa/z48QoEAvrd736ntWvX6rbbbpMkrV69Wt/73ve0c+dO3XDDDZHrHAAQ1dr1HlAgEJAkJSYmSpLKy8t1+vRpZWdnh2qGDx+ugQMHqqys7KxzNDU1KRgMthoAgK6vzQHU0tKiRYsW6cYbb9TIkSMlSTU1NYqNjVVCQkKr2qSkJNXU1Jx1nqKiIvn9/tBIS0tra0sAgCjS5gDKz8/Xvn37tG7dunY1UFhYqEAgEBqHDx9u13wAgOjQpj8AWLBggd59913t2LFDAwYMCB1PTk7WqVOnVFdX1+ouqLa2VsnJyWedy+fzyefztaUNAEAU83QH5JzTggULtGHDBm3fvl3p6emtzo8ZM0Y9e/ZUcXFx6FhlZaUOHTqkrKysyHQMAOgSPN0B5efna+3atdq0aZPi4uJC7+v4/X716tVLfr9fc+bMUUFBgRITExUfH6+FCxcqKyuLT8ABAFrxFECrVq2SJE2YMKHV8dWrV2v27NmSpBdeeEHdunVTXl6empqalJOTo1deeSUizQIAuo4Y55yzbuLbgsGg/H6/dRuAhg0bFnbtwoULPc09ffp0T/Xneg/1YmtubvZUv23btrBr77jjDq/toJMLBAKKj48/53n2ggMAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACba9HUMQGfhZYuau+++29Pc+fn5YddeffXVnubuTHbv3h127dNPP+1p7j/96U9e28ElhDsgAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgLzh0uKSkpLBrR4wY4Wnul19+Oeza4cOHe5q7M9m1a1fYtc8884ynuTdt2hR2bUtLi6e5gfPhDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgKx4oMTHRU/1rr73mqf76668Pu3bw4MGe5u4sPv74Y0/1zz33nKf6rVu3hl37zTffeJobsMIdEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMsBdclMjMzPRU/+CDD4ZdO27cOE9zX3XVVZ7qOwuve6
S99NJLYdf+3//9n6e5GxoaPNUDXRF3QAAAE54CqKioSGPHjlVcXJz69++vadOmqbKyslXNhAkTFBMT02rMmzcvok0DAKKfpwAqLS1Vfn6+du7cqQ8++ECnT5/WpEmTzng54b777lN1dXVoLF++PKJNAwCin6f3gLZs2dLq8Zo1a9S/f3+Vl5dr/PjxoeOXX365kpOTI9MhAKBLatd7QIFAQNKZX2j25ptvqm/fvho5cqQKCwt18uTJc87R1NSkYDDYagAAur42fwqupaVFixYt0o033qiRI0eGjt99990aNGiQUlNTtXfvXj388MOqrKzUO++8c9Z5ioqKtHTp0ra2AQCIUm0OoPz8fO3bt08fffRRq+Nz584N/XvUqFFKSUnRxIkTdfDgQQ0ZMuSMeQoLC1VQUBB6HAwGlZaW1ta2AABRok0BtGDBAr377rvasWOHBgwYcN7a//39yoEDB84aQD6fTz6fry1tAACimKcAcs5p4cKF2rBhg0pKSpSenn7Bn6moqJAkpaSktKlBAEDX5CmA8vPztXbtWm3atElxcXGqqamRJPn9fvXq1UsHDx7U2rVrdccdd6hPnz7au3evFi9erPHjxysjI6NDLgAAEJ08BdCqVask/fePTb9t9erVmj17tmJjY7Vt2za9+OKLamhoUFpamvLy8vTYY49FrGEAQNfg+SW480lLS1NpaWm7GsLZTZ8+vUPrO9KXX34Zdu3mzZs9zd3c3Bx27bPPPutp7rq6Ok/1ALxhLzgAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGAixl1of52LLBgMyu/3W7cBAGinQCCg+Pj4c57nDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJTwG0atUqZWRkKD4+XvHx8crKytJ7770XOt/Y2Kj8/Hz16dNHvXv3Vl5enmprayPeNAAg+nkKoAEDBmjZsmUqLy/X7t27ddttt2nq1Kn6/PPPJUmLFy/W5s2btX79epWWluro0aOaMWNGhzQOAIhyrp2uvPJK9/rrr7u6ujrXs2dPt379+tC5L7/80klyZWVlYc8XCAScJAaDwWBE+QgEAuf9fd/m94Cam5u1bt06NTQ0KCsrS+Xl5Tp9+rSys7NDNcOHD9fAgQNVVlZ2znmampoUDAZbDQBA1+c5gD777DP17t1bPp9P8+bN04YNG3TdddeppqZGsbGxSkhIaFWflJSkmpqac85XVFQkv98fGmlpaZ4vAgAQfTwH0LBhw1RRUaFdu3Zp/vz5mjVrlr744os2N1BYWKhAIBAahw8fbvNcAIDo0cPrD8TGxmro0KGSpDFjxugvf/mLXnrpJc2cOVOnTp1SXV1dq7ug2tpaJScnn3M+n88nn8/nvXMAQFRr998BtbS0qKmpSWPGjFHPnj1VXFwcOldZWalDhw4pKyurvU8DAOhiPN0BFRYWKjc3VwMHDlR9fb3Wrl2rkpISbd26VX6/X3PmzFFBQYESExMVHx+vhQsXKisrSzfccENH9Q8AiFKeAujYsWP62c9+purqavn9fmVkZGjr1q26/fbbJUkvvPCCunXrpry8PDU1NSknJ0evvPJKhzQOAIhuMc45Z93EtwWDQfn9fus2AADtFAgEFB8ff87z7AUHADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMNHpAqiTbcwAAGijC/0+73QBVF9fb90CACACLvT7vNPtBdfS0qKjR48qLi5OMTExoePBYFBpaWk6fPjwefcWinZcZ9dxKV
yjxHV2NZG4Tuec6uvrlZqaqm7dzn2f4/kL6Tpat27dNGDAgHOej4+P79KL/z9cZ9dxKVyjxHV2Ne29znA2le50L8EBAC4NBBAAwETUBJDP59OSJUvk8/msW+lQXGfXcSlco8R1djUX8zo73YcQAACXhqi5AwIAdC0EEADABAEEADBBAAEATERNAK1cuVJXX321LrvsMmVmZuqTTz6xbiminnzyScXExLQaw4cPt26rXXbs2KEpU6YoNTVVMTEx2rhxY6vzzjk98cQTSklJUa9evZSdna39+/fbNNsOF7rO2bNnn7G2kydPtmm2jYqKijR27FjFxcWpf//+mjZtmiorK1vVNDY2Kj8/X3369FHv3r2Vl5en2tpao47bJpzrnDBhwhnrOW/ePKOO22bVqlXKyMgI/bFpVlaW3nvvvdD5i7WWURFAb7/9tgoKCrRkyRJ9+umnGj16tHJycnTs2DHr1iJqxIgRqq6uDo2PPvrIuqV2aWho0OjRo7Vy5cqznl++fLlWrFihV199Vbt27dIVV1yhnJwcNTY2XuRO2+dC1ylJkydPbrW2b7311kXssP1KS0uVn5+vnTt36oMPPtDp06c1adIkNTQ0hGoWL16szZs3a/369SotLdXRo0c1Y8YMw669C+c6Jem+++5rtZ7Lly836rhtBgwYoGXLlqm8vFy7d+/WbbfdpqlTp+rzzz+XdBHX0kWBcePGufz8/NDj5uZml5qa6oqKigy7iqwlS5a40aNHW7fRYSS5DRs2hB63tLS45ORk98wzz4SO1dXVOZ/P59566y2DDiPju9fpnHOzZs1yU6dONemnoxw7dsxJcqWlpc65/65dz5493fr160M1X375pZPkysrKrNpst+9ep3PO3XLLLe7++++3a6qDXHnlle7111+/qGvZ6e+ATp06pfLycmVnZ4eOdevWTdnZ2SorKzPsLPL279+v1NRUDR48WPfcc48OHTpk3VKHqaqqUk1NTat19fv9yszM7HLrKkklJSXq37+/hg0bpvnz5+v48ePWLbVLIBCQJCUmJkqSysvLdfr06VbrOXz4cA0cODCq1/O71/k/b775pvr27auRI0eqsLBQJ0+etGgvIpqbm7Vu3To1NDQoKyvroq5lp9uM9Lu+/vprNTc3KykpqdXxpKQk/e1vfzPqKvIyMzO1Zs0aDRs2TNXV1Vq6dKluvvlm7du3T3FxcdbtRVxNTY0knXVd/3euq5g8ebJmzJih9PR0HTx4UI8++qhyc3NVVlam7t27W7fnWUtLixYtWqQbb7xRI0eOlPTf9YyNjVVCQkKr2mhez7NdpyTdfffdGjRokFJTU7V37149/PDDqqys1DvvvGPYrXefffaZsrKy1NjYqN69e2vDhg267rrrVFFRcdHWstMH0KUiNzc39O+MjAxlZmZq0KBB+sMf/qA5c+YYdob2uvPOO0P/HjVqlDIyMjRkyBCVlJRo4sSJhp21TX5+vvbt2xf171FeyLmuc+7cuaF/jxo1SikpKZo4caIOHjyoIUOGXOw222zYsGGqqKhQIBDQH//4R82aNUulpaUXtYdO/xJc37591b179zM+gVFbW6vk5GSjrjpeQkKCrr32Wh04cMC6lQ7xv7W71NZVkgYPHqy+fftG5douWLBA7777rj788MNWX5uSnJysU6dOqa6urlV9tK7nua7zbDIzMyUp6tYzNjZWQ4cO1ZgxY1RUVKTRo0frpZdeuqhr2ekDKDY2VmPGjFFxcXHoWEtLi4qLi5WVlWXYWcc6ceKEDh48qJSUFOtWOkR6erqSk5NbrWswGNSuXbu69LpK0pEjR3T8+PGoWlvnnBYsWKANGzZo+/btSk9Pb3V+zJgx6tmzZ6v1rKys1KFDh6JqPS90nWdTUVEhSVG1nmfT0tKipqami7uWEf1IQwdZt26d8/l8bs2aNe6LL75wc+fOdQkJCa6mpsa6tYj51a9+5UpKSlxVVZX785//7LKzs13fvn3dsWPHrF
trs/r6erdnzx63Z88eJ8k9//zzbs+ePe6rr75yzjm3bNkyl5CQ4DZt2uT27t3rpk6d6tLT090333xj3Lk357vO+vp698ADD7iysjJXVVXltm3b5n7wgx+4a665xjU2Nlq3Hrb58+c7v9/vSkpKXHV1dWicPHkyVDNv3jw3cOBAt337drd7926XlZXlsrKyDLv27kLXeeDAAffUU0+53bt3u6qqKrdp0yY3ePBgN378eOPOvXnkkUdcaWmpq6qqcnv37nWPPPKIi4mJce+//75z7uKtZVQEkHPOvfzyy27gwIEuNjbWjRs3zu3cudO6pYiaOXOmS0lJcbGxse6qq65yM2fOdAcOHLBuq10+/PBDJ+mMMWvWLOfcfz+K/fjjj7ukpCTn8/ncxIkTXWVlpW3TbXC+6zx58qSbNGmS69evn+vZs6cbNGiQu++++6Luf57Odn2S3OrVq0M133zzjfvlL3/prrzySnf55Ze76dOnu+rqarum2+BC13no0CE3fvx4l5iY6Hw+nxs6dKh78MEHXSAQsG3co3vvvdcNGjTIxcbGun79+rmJEyeGwse5i7eWfB0DAMBEp38PCADQNRFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDx/wAzgwe7F079ZQAAAABJRU5ErkJggg==",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# visualize first dataset\n",
"x = next(iter(dataloader))[0][:1]\n",
"x_as_image = make_grid(x.float(), nrow=1)\n",
"plt.figure()\n",
"plt.imshow(x_as_image.permute(1, 2, 0).cpu().numpy())\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABkQAAAMtCAYAAADOtR3+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAq9ElEQVR4nO3df2yV9dn48avABGQtGxILlSpuyzYbYxuhEIZbRPtI1JBhtojbMhA3l23VSaoukiWwZSz+yGKI46jbMsU5psRlMLMsNYwtMA2TCmFh63RTWazD8kO3FjpFpef5w9jnywSl/bY9nIvXKzl/nPuc+9zX/Qcf2rx7n7uiWCwWAwAAAAAAILERpR4AAAAAAABgqAkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJDeqFIP0F+9vb2xe/fuqKysjIqKilKPAwAAAAAAlFCxWIwDBw5ETU1NjBhx7OtAyi6I7N69O2pra0s9BgAAAAAAcALp6OiIKVOmHPP1sgsilZWVEfHWiVVVVZV4GgAAAAAAoJS6u7ujtra2rx8cS9kFkbe/JquqqkoQAQAAAAAAIiLe8zYbbqoOAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6wx5EOjo64sILL4y6uro477zz4pFHHhnuEQAAAAAAgJPMqGE/4KhRsXLlymhoaIjOzs6YNm1aXHbZZTFu3LjhHgUAAAAAADhJDHsQmTx5ckyePDkiIiZNmhQTJ06MV155RRABAAAAAACGTL+/Mmvz5s0xb968qKmpiYqKili/fv073lMoFGLq1KkxZsyYmDlzZmzduvWon7Vt27Y4fPhw1NbW9ntwAAAAAACA49XvINLT0xP19fVRKBSO+vratWujpaUlli9fHtu3b4/6+vqYO3du7N2794j3vfLKK7Fw4cL40Y9+9K7HO3ToUHR3dx/xAAAAAAAA6I+KYrFYHPDOFRWxbt26mD9/ft+2mTNnRmNjY6xatSoiInp7e6O2tjauv/76uOWWWyLircjxP//zP3HttdfGF7/4xXc9xre//e34zne+847tXV1dUVVVNdDRAQCABF685Q8D2m/KbZ8c5EkAAIBS6e7ujvHjx79nN+j3FSLv5vXXX49t27ZFU1PT/x1gxIhoamqKLVu2REREsViMq6++Oi666KL3jCEREUuXLo2urq6+R0dHx2CODAAAAAAAnAQGNYjs378/Dh8+HNXV1Udsr66ujs7OzoiIeOKJJ2Lt2rWxfv36aGhoiIaGhti5c+cxP3P06NFRVVV1xAMAAAAAAKA/Rg33AS+44ILo7e0d7sMCAAAAAAAnsUG9QmTixIkxcuTI2LNnzxHb9+zZE5MmTRrMQwEAAAAAABy3QQ0ip5xySkybNi02btzYt623tzc2btwYs2bNGsxDAQAAAAAAHLd+f2XWwYMH49lnn+17vmvXrtixY0dMmDAhzjzzzGhpaYlFixbF9OnTY8aMGbFy5cro6emJxYsXD+rgAAAAAAAAx6vfQeSpp56KOXPm9D1vaWmJiIhFixbF6tWrY8GCBbFv375YtmxZdHZ2RkNDQ7S2tr7jRusAAAAAAADDpaJYLBZLPUR/dHd3x/jx46OrqyuqqqpKPQ4AAFBCL97yhwHtN+W2Tw7yJAAAQKkcbzcY1HuIAAAAAAAAnIgEEQAAAA
AAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEivbIJIoVCIurq6aGxsLPUoAAAAAABAmSmbINLc3Bzt7e3R1tZW6lEAAAAAAIAyUzZBBAAAAAAAYKAEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACC9sgkihUIh6urqorGxsdSjAAAAAAAAZaZsgkhzc3O0t7dHW1tbqUcBAAAAAADKTNkEEQAAAAAAgIESRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAIL2yCSKFQiHq6uqisbGx1KMAAAAAAABlpmyCSHNzc7S3t0dbW1upRwEAAAAAAMpM2QQRAAAAAACAgRJEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANIrmyBSKBSirq4uGhsbSz0KAAAAAABQZsomiDQ3N0d7e3u0tbWVehQAAAAAAKDMlE0QAQAAAAAAGChBBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEA
EAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvbIJIoVCIerq6qKxsbHUowAAAAAAAGWmbIJIc3NztLe3R1tbW6lHAQAAAAAAykzZBBEAAAAAAICBEkQAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPTKJogUCoWoq6uLxsbGUo8CAAAAAACUmbIJIs3NzdHe3h5tbW2lHgUAAAAAACgzZRNEAAAAAAAABkoQAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAA
AgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASK9sgkihUIi6urpobGws9SgAAAAAAECZKZsg0tzcHO3t7dHW1lbqUQAAAAAAgDJTNkEEAAAAAABgoAQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvbIJIoVCIerq6qKxsbHUowAAAAAAAGWmbIJIc3NztLe3R1tbW6lHAQAAAAAAykzZBBEAAAAAAICBEkQAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRA
AAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAA
AA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvZIEkSuuuCI++MEPxmc/+9lSHB4AAAAAADjJlCSI3HDDDfHTn/60FIcGAAAAAABOQiUJIhdeeGFUVlaW4tAAAAAAAMBJqN9BZPPmzTFv3ryoqamJioqKWL9+/TveUygUYurUqTFmzJiYOXNmbN26dTBmBQAAAAAAGJB+B5Genp6or6+PQqFw1NfXrl0bLS0tsXz58ti+fXvU19fH3LlzY+/evQMa8NChQ9Hd3X3EAwAAAAAAoD/6HUQuvfTSWLFiRVxxxRVHff3OO++Ma6+9NhYvXhx1dXVx7733xqmnnhr33XffgAa89dZbY/z48X2P2traAX0OAAAAAABw8hrUe4i8/vrrsW3btmhqavq/A4wYEU1NTbFly5YBfebSpUujq6ur79HR0TFY4wIAAAAAACeJUYP5Yfv374/Dhw9HdXX1Edurq6vj6aef7nve1NQUf/rTn6KnpyemTJkSjzzySMyaNeuonzl69OgYPXr0YI4JAAAAAACcZAY1iByv3/72t6U4LAAAAAAAcJIa1K/MmjhxYowcOTL27NlzxPY9e/bEpEmTBvNQAAAAAAAAx21Qg8gpp5wS06ZNi40bN/Zt6+3tjY0bNx7zK7EAAAAAAACGWr+/MuvgwYPx7LPP9j3ftWtX7NixIyZMmBBnnnlmtLS0xKJFi2L69OkxY8aMWLlyZfT09MTixYsHdXAAAAAAAIDj1e8g8tRTT8WcOXP6nre0tERExKJFi2L16tWxYMGC2LdvXyxbtiw6OzujoaEhWltb33GjdQAAAAAAgOFSUSwWi6Ueoj+6u7tj/Pjx0dXVFVVVVaUeBwAAKKEXb/nDgPabctsnB3kSAACgVI63GwzqPUQAAAAAAABORIIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApFc2QaRQKERdXV00NjaWehQAAAAAAKDMlE0QaW5ujvb29mhrayv1KAAAAAAAQJkZVeoB+qtYLEZERHd3d4knAQAASu3AoZ4B7ef3CQAAyOPtn+/f7gfHUnZB5MCBAxERUVtbW+JJAACAsrWy1AMAAACD7cCBAzF+/Phjvl5RfK9kcoLp7e2N3bt3R2VlZVRUVJR6HChb3d3dUVtbGx0dHVFVVVXqcYCErDPAULPOAEPNOgMMJWsMDJ5isRgHDhyImpqaGDHi2HcKKbsrREaMGBFTpkwp9RiQRlVVlf90gSFlnQGGmnUGGGrWGWAoWWNgcLzblSFvK5ubqgMAAAAAAAyUIAIAAAAAAKQniMBJavTo0bF8+fIYPXp0qUcBkrLOAEPNOgMMNesMMJSsMTD8yu6m6gAAAAAAAP3lChEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBJI6cOBALFmyJM4666wYO3ZsfOITn4i2trZ33efQoUPxrW99K84666wYPXp0TJ06Ne67775hmhgoNwNZZ9asWRP19f
Vx6qmnxuTJk+Oaa66Jl19+eZgmBk5kmzdvjnnz5kVNTU1UVFTE+vXrj3i9WCzGsmXLYvLkyTF27NhoamqKv//97+/5uYVCIaZOnRpjxoyJmTNnxtatW4foDIAT3VCsM7feems0NjZGZWVlnH766TF//vx45plnhvAsgBPZUP0887bbbrstKioqYsmSJYM7OJxEBBFI6stf/nJs2LAhHnzwwdi5c2dccskl0dTUFP/85z+Puc+VV14ZGzdujJ/85CfxzDPPxEMPPRQf+9jHhnFqoJz0d5154oknYuHChfGlL30p/vKXv8QjjzwSW7dujWuvvXaYJwdORD09PVFfXx+FQuGor99xxx1x1113xb333htPPvlkjBs3LubOnRuvvfbaMT9z7dq10dLSEsuXL4/t27dHfX19zJ07N/bu3TtUpwGcwIZindm0aVM0NzfHH//4x9iwYUO88cYbcckll0RPT89QnQZwAhuKdeZtbW1t8cMf/jDOO++8wR4bTioVxWKxWOohgMH16quvRmVlZfzqV7+Kyy+/vG/7tGnT4tJLL40VK1a8Y5/W1ta46qqr4vnnn48JEyYM57hAGRrIOvP9738/7rnnnnjuuef6tv3gBz+I22+/PV588cVhmRsoDxUVFbFu3bqYP39+RLz115Q1NTVx4403xk033RQREV1dXVFdXR2rV6+Oq6666qifM3PmzGhsbIxVq1ZFRERvb2/U1tbG9ddfH7fccsuwnAtwYhqsdea/7du3L04//fTYtGlTfOpTnxqq8YEyMJjrzMGDB+P888+Pu+++O1asWBENDQ2xcuXKYTgLyMcVIpDQm2++GYcPH44xY8YcsX3s2LHx+OOPH3WfRx99NKZPnx533HFHnHHGGfHRj340brrppnj11VeHY2SgzAxknZk1a1Z0dHTEb37zmygWi7Fnz574xS9+EZdddtlwjAyUsV27dkVnZ2c0NTX1bRs/fnzMnDkztmzZctR9Xn/99di2bdsR+4wYMSKampqOuQ9w8hrIOnM0XV1dERH+yAx4h/+fdaa5uTkuv/zyI/YFBmZUqQcABl9lZWXMmjUrvvvd78Y555wT1dXV8dBDD8WWLVviIx/5yFH3ef755+Pxxx+PMWPGxLp162L//v3x9a9/PV5++eW4//77h/kMgBPdQNaZ2bNnx5o1a2LBggXx2muvxZtvvhnz5s075uXkAG/r7OyMiIjq6uojtldXV/e99t/2798fhw8fPuo+Tz/99NAMCpStgawz/623tzeWLFkSs2fPjnPPPXfQZwTK20DXmYcffji2b9/+nvdrBI6PK0QgqQcffDCKxWKcccYZMXr06Ljrrrvic5/7XIwYcfR/9r29vVFRURFr1qyJGTNmxGWXXRZ33nlnPPDAA64SAY6qv+tMe3t73HDDDbFs2bLYtm1btLa2xj/+8Y/46le/OsyTAwAMvubm5vjzn/8cDz/8cKlHAZLo6OiIG264IdasWfOOq/OBgRFEIKkPf/jDsWnTpjh48GB0dHTE1q1b44033ogPfehDR33/5MmT44wzzojx48f3bTvnnHOiWCz6bn/gqPq7ztx6660xe/bsuPnmm+O8886LuXPnxt133x333XdfvPTSS8M8PVBOJk2aFBERe/bsOWL7nj17+l77bxMnToyRI0f2ax/g5DWQdeb/dd1118Wvf/3r+P3vfx9TpkwZkhmB8jaQdWbbtm2xd+/eOP/882PUqFExatSo2LRpU9x1110xatSoOHz48JDPDdkIIpDcuHHjYvLkyfGvf/0rHnvssfj0pz991PfNnj07du/eHQcPHuzb9re//S1GjBjhB3rgXR3vOvOf//znHVePjBw5MiLeusEgwLGcffbZMWnSpNi4cWPftu7u7njyySdj1qxZR93nlFNOiWnTph2xT29vb2zcuPGY+wAnr4GsMxFv/Qxz3XXXxbp16+J3v/tdnH322cMxLlCGBrLOXHzxxbFz587YsWNH32P69OnxhS
98IXbs2NH3+xRw/AQRSOqxxx6L1tbW2LVrV2zYsCHmzJkTH//4x2Px4sUREbF06dJYuHBh3/s///nPx2mnnRaLFy+O9vb22Lx5c9x8881xzTXXxNixY0t1GsAJrL/rzLx58+KXv/xl3HPPPfH888/HE088Ed/4xjdixowZUVNTU6rTAE4QBw8e7PtFP+KtG4/u2LEjXnjhhaioqIglS5bEihUr4tFHH42dO3fGwoULo6amJubPn9/3GRdffHGsWrWq73lLS0v8+Mc/jgceeCD++te/xte+9rXo6enpW6eAk8tQrDPNzc3xs5/9LH7+859HZWVldHZ2Rmdnp68dhpPUYK8zlZWVce655x7xGDduXJx22mnuVQQD5KbqkFRXV1csXbo0XnzxxZgwYUJ85jOfie9973vxvve9LyIiXnrppXjhhRf63v/+978/NmzYENdff31Mnz49TjvttLjyyitjxYoVpToF4ATX33Xm6quvjgMHDsSqVavixhtvjA984ANx0UUXxe23316qUwBOIE899VTMmTOn73lLS0tERCxatChWr14d3/zmN6Onpye+8pWvxL///e+44IILorW19Yjv037uuedi//79fc8XLFgQ+/bti2XLlkVnZ2c0NDREa2vrO25mCpwchmKdueeeeyIi4sILLzziWPfff39cffXVQ3cywAlpKNYZYHBVFH1HBQAAAAAAkJyvzAIAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACC9/wUt71o0Sbl9QgAAAABJRU5ErkJggg==",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# check parameter distribution\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"def plot_param_dist(model):\n",
" plot_dist = {}\n",
" for name, param in model.named_parameters():\n",
" if param.requires_grad:\n",
" plot_dist[name] = param.detach().cpu().numpy().flatten()\n",
" # sample 1000\n",
" # replace nan with 10 so NaN parameters stay visible in the histogram\n",
" plot_dist[name] = np.where(np.isnan(plot_dist[name]), 10, plot_dist[name])\n",
" plot_dist[name] = np.random.choice(plot_dist[name], 1000)\n",
" \n",
" plt.figure(figsize=(20, 10))\n",
"\n",
" for idx, (name, dist) in enumerate(plot_dist.items()):\n",
" \n",
" plt.hist(dist, bins=100, density=True)\n",
" \n",
" # log y\n",
" plt.yscale('log')\n",
"\n",
" #plt.legend()\n",
" plt.show()\n",
"\n",
"\n",
"plot_param_dist(d3pm.x0_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiYAAACxCAYAAADwMnaUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkhklEQVR4nO3dfWxV5R0H8G8LtKDQy4uhtSuFbmNDJ1YHWjvMcNrJiFEYZFPDRlESoyuO0mRTtqHGzVVdNl58gbktwDIZjkRgkKhhRUpMylsRJuIqy5igeMucay+ilI4++2Pzjvvr8f7O755T7mn5fpKb9Nzz9pznvPTJeX739+Q45xyIiIiIIiA32wUgIiIi+hgbJkRERBQZbJgQERFRZLBhQkRERJHBhgkRERFFBhsmREREFBlsmBAREVFksGFCREREkcGGCREREUUGGyZEREQUGT3WMHnqqacwZswYDBw4EBUVFdi1a1dP7YqIiIj6iJyeGCvnueeew+zZs7FixQpUVFRgyZIlWLduHVpaWjBy5Mi063Z1deHYsWMYMmQIcnJywi4aERER9QDnHE6cOIHi4mLk5mb+3qNHGiYVFRW46qqr8OSTTwL4b2Nj1KhRuPfee3H//fenXfftt9/GqFGjwi4SERERnQNHjx5FSUlJxuuH3pVz+vRpNDc3o6qq6v87yc1FVVUVmpqaui3f0dGBRCKR/HCwYyIiot5ryJAhgdYPvWHy3nvv4cyZMygsLEz5vrCwEPF4vNvy9fX1iMViyU9paWnYRSIiIqJzJGgYRtZ/lbNw4UK0t7cnP0ePHs12kYiIiChL+oe9wYsuugj9+vVDa2tryvetra0oKirqtnx+fj7y8/PDLgYRERH1QqG/McnLy8OECRPQ0NCQ/K6rqwsNDQ2orKwMe3dERETUh4T+xgQA6urqUF1djYkTJ+Lqq6/GkiVLcPLkSdxxxx09sTsiIiLqI3qkYXLrrbfiH//4Bx544AHE43FcccUVePHFF7sFxGYqFoulTI8ZMyZluq2tLe20XN7L3//+95TpoUOHpl1e7kNb3lpmKezy+NmnRtuHnN/Y2Jh2+dGjR6dMa8dkPWY/x+un3oIIet3I61TS6shrGeu1L+dfccUVKdOrV69Ou355eXnKdNDrMAzaedGm9+3bl3a+9VqWMnl+aNeOVobrrrsu7fr79+/vXtCzyPvZWj5NJs+8ntiGZX3rvZYJ6zFp18Fbb70VvFCKHmmYAMC8efMwb968nto8ERER9UFZ/1UOERER0cfYMCEiIqLI6LGunJ40Z86clGnZByb7dyWtf9iL1hdonQ47psS6Pz/kNoLGN1hp/eja/uT6Wr+/F61/dvr06SnT27ZtU7eZbvtyexs2bEi7vGStA6D7eZSxBNryWoyKRt7Psoxye2Fcd1o9yjgZLYZEK4Oclscsz7MWB5dJbELQY5bXhbzWtRgTuX1ZHnlMPfG8sT77ZZnledfqVIvTsZ4Trc68aM9ROV/uU57ncxFjwjcmREREFBlsmBAREVFksGFCREREkdErY0y0/l2t3172K3rFBch+N61vUOv7C9qXKGnHILcnj8erb1UrszV3gpyWZdLymGjry+tAq0NZfq9+fFmPchl5LVn7+rXzoPWzy/1L2nXgtf6SJUvSlkETNPeDLLM1ZkWLffCKmdHi0mS9yTJo16IW46U9w7T9yxgV7br1KoPcpiyTtg9rPJVWHu260eK5tOeF1z7ktbFq1aq0+9TOq/UZqj2ztPX93CvyGLU4GTlf7lOLJQoD35gQERFRZLBhQkRERJHBhgkRERFFRo5zzmW7EGdLJBLdxsKRtLE1tH45P3EB1rEqgo55YM0zEnTsjUyO2Vomrf+zvb097fas51krTyaxENZxZIKOgWLVE/vX1tHiK6x90pMnT06Ztsa4+MnVYt2G9Txr2wt63q3PND9lsj6ztBw7Wn4L65hI1ntP8tq+fObI/zXW82ytMxnDEj
TXU9jPE0DPoeMnxqS9vR0FBQUZl4FvTIiIiCgy2DAhIiKiyGDDhIiIiCKDDRMiIiKKjF4Z/Dp69Oi0860BTF7zMwkus5TBGrgVNDivJwb5ClonWhCVdp6tMgkA1urVGkyn7U8KGkTdE8Gv2vKSFhQp6yzovXcuaMccdtC0dft+BA28lLTzrAU5B73O/NShFritCZpMMOwfGGj/J/ywlsnPIH4MfiUiIqI+gw0TIiIiigw2TIiIiCgyeuUgfkEH0Mok3kJLmGTtE9b69eTAS9bkQpnEg1gHqLKy9mVqsQbWAa606wTQ612yDsKnzdeSyFmPORPW5GCyzmTiLY024KRkvTe9yt/TMWDatWaNo9GSF2aS0E0bQE4egzYYnMYaw5bJ/avRyqwdozXGI2hskfX/ip9ltGvtoYceSjt9LvCNCREREUUGGyZEREQUGWyYEBERUWT0yhgTrd9P60OTfZlefeLWmJJMcqWk2582MJp2jJnkBLDGL1jjH4LmPQmaE8RPPMa2bdvSLmOtV61PO+ycIXJ9beA1r21a+8XlMdbW1qZML1iwoHtB05DnQNLuRT/5MayxBFq8hSyzNZ+E9RxosUpe14k8hqDxU3K+lt9CLi+vzSVLlqRd3hpL5HUOtPtBqyM5CJ8sc9DYQWvMiZ94Kuv9Io9J1tnSpUu77SNsfGNCREREkcGGCREREUUGGyZEREQUGb0yxiRoX6PsV/Tqj7X27cnfest+di1nQCZ9h5b1pUzGbLDGtQSNMdHqwBpbJNeX/b1erH35WmyQFp9hFcb4IUHH75Dryz5qjTX3i7a+n9wOWhyLJLdpjYPR5gcdcyWT3C1Bc+gELbP1XrHmVfKqA7lP6zHJ57q1DuR1EzR2MZOcXNo2tPipc4FvTIiIiCgy2DAhIiKiyDA3TLZv346bb74ZxcXFyMnJ6dYt4pzDAw88gIsvvhiDBg1CVVUVDh06FFZ5iYiIqA8zx5icPHkS5eXluPPOOzFjxoxu8x9//HEsW7YMq1evRllZGRYtWoQpU6bg4MGDGDhwYCiFlqz5N2Qfmp/ffmu5T6z96tr2wx6Xxs/2tWWs/alB4yes/a+y/JI1JsVrm/KYtPE9tDLIHAEyT4I1H0YmcQHaMvIYrf3kWn4La04Q7bqU25N16rVNydovr+UhCjoGU9AcQ372Yc1fE5R2DNp1Zr03vLah3Y9yeeu1r+VNseY10Xgtrz2j5D5lGTOp56DMDZOpU6di6tSpnvOcc1iyZAl+9KMfYdq0aQCA3/72tygsLMSGDRtw2223BSstERER9WmhxpgcPnwY8XgcVVVVye9isRgqKirQ1NTkuU5HRwcSiUTKh4iIiM5PoTZM4vE4AKCwsDDl+8LCwuQ8qb6+HrFYLPkZNWpUmEUiIiKiXiTreUwWLlyIurq65HQikVAbJ9p4IFo/YCaxE1q/e9Df+Gu/FbfGMmjHOGfOnG7r+Mnvko52XuT2tdgDaz+/dsxaXABgH9tGy8EhyeXlebfGGmnHqO3fixZTYs3xo7HmBLHWsR/WZ4Q1BkSL/dFimYLmOfIqk1bvYZ9nub4WyyC3L3OIaM95P2XQ6lEbO0c+4+R8OS1jSuR5l/mwrM9gP+MDWfcRdmyRH6G+MSkqKgIAtLa2pnzf2tqanCfl5+ejoKAg5UNERETnp1AbJmVlZSgqKkJDQ0Pyu0QigZ07d6KysjLMXREREVEfZO7K+eCDD/DXv/41OX348GHs27cPw4cPR2lpKWpra/GTn/wEY8eOTf5cuLi4uNvrJCIiIiIpxznnLCts27YNX/nKV7p9X11djVWrVsE5hwcffBDPPPMM2tracO211+Lpp5/G5z73OV/bTyQSiMViaZcpLy9PO1+L99DGN/FaRusD1n6Tbx3XIWgsg9Yv6LW+dcwSrV61etZiTOR1YK1j2Z/rJ6+KNUZDy19h7aeX/ehye9Z+eW2+13daLgfreCP79+/vts+zjR49Ou36Vpn0iVvjWHo6/k
krj/b88RLGuCvptqfdz/K5HTRuRnseeLHGgFhjzqx1rD0jrXFyXnFDQf83Sdp5BoD29vZAYRnmNybXXXcd0rVlcnJy8PDDD+Phhx/OuFBERER0fuJYOURERBQZbJgQERFRZGQ9j0kmtLEwtH4/LQeJ13dB+xq12ARrnhKZh8RrPJB02/PavhZTosUryOXlMcn1tb7KoPETWp4Gr355a1yNnxiOs2ljY8gYEut4QVr5vVjzlmi5Vqx91kHzZWSSh0E7b0H7/q3PHGtMidyfjE3yGrtLiyWS503uU+5De+ZotPPq5zl9Nj/Xodym9dq33u9S2HEzfnJyaceg5WaRdeInxiQovjEhIiKiyGDDhIiIiCKDDRMiIiKKDHMek57mJ4/J5MmTU6a1/tqgeRG8tqH17Wl9zHL9oONOBI1F8CpT0HgGrb9048aN3cpwNpnfImjOkEzy1wQdH0T2y8txKrQ6sl7bYVz71tgi7brQ8phoeYkkuT8Zt6PFFvmh5QWxxpRpeZA01txMfuLmZJm05bX73Xqegz5TJa/xvyQtHlHSxvsKOzeM9f+IFhPjZx/WODrtuQ0Ez2PCNyZEREQUGWyYEBERUWSwYUJERESR0SvzmFj7PjPJYxI03sLaLy+F/Xv5TMZH0MaR0Oo5kzFLzqbFX2jT8jzL/uFMxtaw9hlrY9v4iXuxkHXuJ79FJnlAzib79uXyWuyBNWZF8jNOjCSvZUnru7fm1LHmdtFiCSQ/97e2jPZcDCNWLx3tOtByPVnjR7yWkfeflqvFep6sOYCsz1Q/512brz03zwW+MSEiIqLIYMOEiIiIIoMNEyIiIoqMPpHHxJozwM+4MVouB2sMirWfz9onbc1v4acM1t+7W2ljLsjzLM+B7M+15nrwquNMxpqx7FMTNJdDJtehNcbEei03NjamXV7mt9DuPY2f8gbdh7bPnjhP6eZnksfEmjNHi4+w3s/atW6NgfFDW8caO2SNh9KeJ1oOEW1/XjEuWgyYFnsn47FWr17dbR8S85gQERFRn8GGCREREUUGGyZEREQUGWyYEBERUWT0iQRrkhbsqi0P2INdwx4AT7JuT/KzfNiBljLoSgZZacFyUtA6kgPoeSUbswbcWROaWQci1FivUy/aebYm3rImidMC/iRroKjX8QUNZtcS5WnHYE1iJfevDS7n57xr16KkBQxr97P1PMnty+eJvLf8BCAHfY5atydZk1Rq/ATny4BYazLQbOAbEyIiIooMNkyIiIgoMtgwISIiosjolTEmWgI1SYsb8BNvofXHav10WkxK2P16YSQfsiZkCjshmzWmRBuYTfZJ90SCNTnoV9CkUNY6yyTpnZZIy5poS5ZBiz2w1okWF+BnYLWg50Hbh5b4SsaIWGNerDEpXtuwxr1pz0QtkZ58ZmrPcbl/GSMmZZJkzjpfqyPt/4QWUybrRLI+EwE9Nkd7bgYdWDQTfGNCREREkcGGCREREUUGGyZEREQUGb0yxkT2A1r73bT4ED/LaPEYWg4PrYzafOvv5eX+/fS/hv37dmu8hjWfhYzvsPbfem1Tm6/F1Vj7pK05RbQ4AHkdyDryWkfS+v6D0gak0+5FWR5573nFmGi0e0Hrd5f7lMck52sxZ1reFMmrfFo9a/UUNF+Ndu1q86287ndr3Jyk3V9anWjxW3J9a3m8yJgS7bzL+Kds4BsTIiIiigxTw6S+vh5XXXUVhgwZgpEjR2L69OloaWlJWebUqVOoqanBiBEjMHjwYMycOROtra2hFpqIiIj6JlPDpLGxETU1NdixYwe2bNmCzs5O3HjjjTh58mRymQULFmDTpk1Yt24dGhsbcezYMcyYMSP0ghMREVHfY4oxefHFF1OmV61ahZEjR6K5uRlf/vKX0d7ejt/85jdYs2YNrr/+egDAypUrcckll2DHjh245pprQim01gem9e9qeRkAPQ+BZO2vteZBkGWWfZtyvlYHfspkzWOQyT7TseaCkLQ68WLtB9fiH6
x9xtr+tetEkrFFXte63IfWbx52LJIWM2bt19fGJ/KzD4113Bc/8U3pti+vdWt+DK8yafvUjlHW6/79+9MuL4/Ba6yqs1njbPzEHmr3lzUGRYuD0Z4Hsk60543kJw4naPxiNsbOCRRj0t7eDgAYPnw4AKC5uRmdnZ2oqqpKLjNu3DiUlpaiqakpyK6IiIjoPJDxr3K6urpQW1uLSZMm4bLLLgMAxONx5OXldWuBFRYWIh6Pe26no6MDHR0dyelEIpFpkYiIiKiXy/iNSU1NDQ4cOIC1a9cGKkB9fT1isVjyM2rUqEDbIyIiot4rozcm8+bNw+bNm7F9+3aUlJQkvy8qKsLp06fR1taW8taktbUVRUVFnttauHAh6urqktOJREJtnMhcDFp/sdZvmMlv/q19lVq8hBa7EHYMi59YA0n7zb0W7xCUVofaeda2B+j13tOxCdY4He2YMxk/xJpzRyujNlaOFqMi61yeIy1vide9Yq1Xa34b7f7TYsQkGY+h5cPwc11q+9Rij6x5RrSYEuszNZNj1u5f6/2tra8917V7TYtJ0ZbPRNh5ijJhemPinMO8efOwfv16bN26FWVlZSnzJ0yYgAEDBqChoSH5XUtLC44cOYLKykrPbebn56OgoCDlQ0REROcn0xuTmpoarFmzBhs3bsSQIUOScSOxWAyDBg1CLBbD3LlzUVdXh+HDh6OgoAD33nsvKisrQ/tFDhEREfVdpobJ8uXLAXR/pbdy5crk69TFixcjNzcXM2fOREdHB6ZMmYKnn346lMISERFR35bjnHPZLsTZEokEYrFY2mUmT56cdr61P9gP2Q8u+w5l3IvWby77xbXxDLSYFet4QT3BGtfS2NiYdnvl5eUp09Y8DJnkFLGOz6PFFgSNVbCO5aHFW3itb43JsuYZWb16dfeCnmX06NFp59fW1qZMa7EKkp94qqBjHmn7tMYOWc9JJrRj1JaX19rSpUvTri+f29Zxo4LeG5msI49RG5/Legza8tq9pcWoAPZrUcuZoz23gf+mEgkSlsGxcoiIiCgy2DAhIiKiyGDDhIiIiCKjV8aYaH3SWj+i7LN+6KGH1G0EzXtgzS+RSf9pOn5yjFh/v671l2rLa/ktpk2blnZ9ydpnHYag8RjWcSqs141WHi9aPIW1DNbzbB3rKox4DO3asF5LQc+LNi6M5GesHCt5DFoODS2WSMaMaax1lMkYL0Gfs9b1g+7Pem/6WcYa96LdzwBjTIiIiKgPYcOEiIiIIoMNEyIiIoqMXhljYs1vkUn8hpaPImgeEbm+/D160DFYrH3mfrahjV1hzc2wf//+tMtrsUQa6zkC7LFEQVn7dyVr7gmv47PePzKXghbfEPQ8B43D8UNbJ2i8kjV+S8tXoV3bXrlb5HnSxnHRxvOR+9RiD7TzHDQux+uYNVrcilZnQePygsZHhZHLRdLKoN3PAGNMiIiIqA9hw4SIiIgigw0TIiIiigzTIH5RZe2b9NMHrfUdyv5XmXtB63e3jsWhjdUj96fxOj5rPIP19/BWQXPHyDrR8jB4rRO0TzhoXhJr7pZMWK9N67UWdP8ard/fqw6t1651rBvJet79jIFiFTR/hbY9P/kt0gk6Vo7kdR3JfVjHqtGEXceaMOLmrGX2E2MSFN+YEBERUWSwYUJERESRwYYJERERRUavjDHR4jvmzJmTMi37a2UfWia/+V+yZEnaMkpamYPGHli3nwlrn6+ctsYmhDHuS7r9+6kT6/g/Whm1/l5rngRrnfjpd5e0Ywwag2KNTbJey173t3XcFXl/aXEsQeMl5DNs1apVSMfPvWaNtdOek/IZo7HGd8hjkHUi5/t5HmjjpMk6keddlln7PxA0P441T4mfnFzac1D7X3Iu8I0JERERRQYbJkRERBQZbJgQERFRZLBhQkRERJHRKwfxmzZtWsq0FmBoDUD0WiaToKOzyeReWsImLZDLmgwtkzrQtqmVSQsIbGxsTLu9xYsXp0zLAMCwk1
x5LaPRAj+DDgqYSRKpdOv7qbOgwahyWku8NX/+/JRpeW9o+9cC072OWX6nBXYGHfgsaCI9beBErTxetGeENeDXOlij3L8MbtUCU7Vz5nXdWIOeteemlnRS27/1eWNNlua1D2ugttym9twGOIgfERER9SFsmBAREVFksGFCREREkdErE6xZE/9o/b1+BpvTkv8ETS4m+x7lMcn+Xpk0zjromFffpuzjletoA4tp9WrtT5UxJVryMS0pnp9EYFq/u6SdV61OrLFEWv+w1ifeE/FVWiyCFmMil9cSaWkD3snz7lV+GfMly6DN1+pdez4EHSxOri/L63WvytgbOW2NNZD1rMWYyP1Z46Os17rX9rVltFhA7X603q8aa2I+r2s9aKLKMJJzWvGNCREREUUGGyZEREQUGWyYEBERUWT0yjwm5eXlaedb+2/9xFtYc2hY+/q1eAgtpiTowGleZdL6L7XBobQBsZYuXZq2jPI8a33M1vMuywPoA1Zpxxw0xkOLt5DXQdBzBOi5GOQ6Yee3kHlMtHguSTtmP33k8pitcWphD+JnzXviJ8+JNfeRxhpLJPOYaDEk2nnXyutnUL+gzxTtWrPmq9Hi6LTyeNWJ9dqT5PyNGzemXR5gHhMiIiLqQ0wNk+XLl+Pyyy9HQUEBCgoKUFlZiRdeeCE5/9SpU6ipqcGIESMwePBgzJw5E62traEXmoiIiPomU8OkpKQEjz76KJqbm7Fnzx5cf/31mDZtGl5//XUAwIIFC7Bp0yasW7cOjY2NOHbsGGbMmNEjBSciIqI+yAU0bNgw9+tf/9q1tbW5AQMGuHXr1iXnvfHGGw6Aa2pq8r299vZ2B4Affvjhhx9++OmFn/b29kDtioxjTM6cOYO1a9fi5MmTqKysRHNzMzo7O1FVVZVcZty4cSgtLUVTU9MnbqejowOJRCLlQ0REROcnc8Pktddew+DBg5Gfn4+7774b69evx6WXXop4PI68vLxuEbyFhYWIx+OfuL36+nrEYrHkZ9SoUeaDICIior7B3DD5/Oc/j3379mHnzp245557UF1djYMHD2ZcgIULF6K9vT35OXr0aMbbIiIiot7NPFZOXl4ePvvZzwIAJkyYgN27d2Pp0qW49dZbcfr0abS1taW8NWltbUVRUdEnbi8/Px/5+fn2khMREVGfEziPSVdXFzo6OjBhwgQMGDAADQ0NyXktLS04cuQIKisrg+6GiIiIzgOmNyYLFy7E1KlTUVpaihMnTmDNmjXYtm0bXnrpJcRiMcydOxd1dXUYPnw4CgoKcO+996KyshLXXHNNT5WfiIiI+hBTw+T48eOYPXs23n33XcRiMVx++eV46aWX8NWvfhUAsHjxYuTm5mLmzJno6OjAlClT8PTTT5sK5KKVIZ+IiIgMgv4fj9xYOW+//TZ/mUNERNRLHT16FCUlJRmvH7mGSVdXF44dOwbnHEpLS3H06NFAgwGd7xKJBEaNGsV6DIB1GBzrMBysx+BYh8F9Uh0653DixAkUFxcjNzfzEFbzr3J6Wm5uLkpKSpKJ1j4el4eCYT0GxzoMjnUYDtZjcKzD4LzqMBaLBd4uRxcmIiKiyGDDhIiIiCIjsg2T/Px8PPjgg0y+FhDrMTjWYXCsw3CwHoNjHQbX03UYueBXIiIiOn9F9o0JERERnX/YMCEiIqLIYMOEiIiIIoMNEyIiIoqMyDZMnnrqKYwZMwYDBw5ERUUFdu3ale0iRVZ9fT2uuuoqDBkyBCNHjsT06dPR0tKSssypU6dQU1ODESNGYPDgwZg5cyZaW1uzVOLoe/TRR5GTk4Pa2trkd6xDf9555x1861vfwogRIzBo0CCMHz8ee/bsSc53zuGBBx7AxRdfjEGDBqGqqgqHDh3KYomj5cyZM1i0aBHKysowaNAgfOYzn8GPf/zjlPFHWIeptm/fjptvvhnFxcXIycnBhg0bUub7qa/3338fs2bNQkFBAYYOHYq5c+figw8+OIdHkX3p6rGzsx
P33Xcfxo8fjwsvvBDFxcWYPXs2jh07lrKNMOoxkg2T5557DnV1dXjwwQexd+9elJeXY8qUKTh+/Hi2ixZJjY2NqKmpwY4dO7BlyxZ0dnbixhtvxMmTJ5PLLFiwAJs2bcK6devQ2NiIY8eOYcaMGVksdXTt3r0bv/zlL3H55ZenfM861P3rX//CpEmTMGDAALzwwgs4ePAgfv7zn2PYsGHJZR5//HEsW7YMK1aswM6dO3HhhRdiypQpOHXqVBZLHh2PPfYYli9fjieffBJvvPEGHnvsMTz++ON44oknksuwDlOdPHkS5eXleOqppzzn+6mvWbNm4fXXX8eWLVuwefNmbN++HXfddde5OoRISFePH374Ifbu3YtFixZh7969eP7559HS0oJbbrklZblQ6tFF0NVXX+1qamqS02fOnHHFxcWuvr4+i6XqPY4fP+4AuMbGRuecc21tbW7AgAFu3bp1yWXeeOMNB8A1NTVlq5iRdOLECTd27Fi3ZcsWN3nyZDd//nznHOvQr/vuu89de+21nzi/q6vLFRUVuZ/97GfJ79ra2lx+fr77/e9/fy6KGHk33XSTu/POO1O+mzFjhps1a5ZzjnWoAeDWr1+fnPZTXwcPHnQA3O7du5PLvPDCCy4nJ8e9884756zsUSLr0cuuXbscAPfWW28558Krx8i9MTl9+jSam5tRVVWV/C43NxdVVVVoamrKYsl6j/b2dgDA8OHDAQDNzc3o7OxMqdNx48ahtLSUdSrU1NTgpptuSqkrgHXo1x//+EdMnDgR3/jGNzBy5EhceeWV+NWvfpWcf/jwYcTj8ZR6jMViqKioYD3+z5e+9CU0NDTgzTffBADs378fr7zyCqZOnQqAdWjlp76ampowdOhQTJw4MblMVVUVcnNzsXPnznNe5t6ivb0dOTk5GDp0KIDw6jFyg/i99957OHPmDAoLC1O+LywsxF/+8pcslar36OrqQm1tLSZNmoTLLrsMABCPx5GXl5e8eD5WWFiIeDyehVJG09q1a7F3717s3r272zzWoT9/+9vfsHz5ctTV1eEHP/gBdu/eje9+97vIy8tDdXV1sq687m/W43/df//9SCQSGDduHPr164czZ87gkUcewaxZswCAdWjkp77i8ThGjhyZMr9///4YPnw46/QTnDp1Cvfddx9uv/325EB+YdVj5BomFExNTQ0OHDiAV155JdtF6VWOHj2K+fPnY8uWLRg4cGC2i9NrdXV1YeLEifjpT38KALjyyitx4MABrFixAtXV1VkuXe/whz/8Ac8++yzWrFmDL3zhC9i3bx9qa2tRXFzMOqRI6OzsxDe/+U0457B8+fLQtx+5rpyLLroI/fr16/Zrh9bWVhQVFWWpVL3DvHnzsHnzZrz88ssoKSlJfl9UVITTp0+jra0tZXnW6f81Nzfj+PHj+OIXv4j+/fujf//+aGxsxLJly9C/f38UFhayDn24+OKLcemll6Z8d8kll+DIkSMAkKwr3t+f7Hvf+x7uv/9+3HbbbRg/fjy+/e1vY8GCBaivrwfAOrTyU19FRUXdflzx73//G++//z7rVPi4UfLWW29hy5YtybclQHj1GLmGSV5eHiZMmICGhobkd11dXWhoaEBlZWUWSxZdzjnMmzcP69evx9atW1FWVpYyf8KECRgwYEBKnba0tODIkSOs0/+54YYb8Nprr2Hfvn3Jz8SJEzFr1qzk36xD3aRJk7r9VP3NN9/E6NGjAQBlZWUoKipKqcdEIoGdO3eyHv/nww8/RG5u6qO5X79+6OrqAsA6tPJTX5WVlWhra0Nzc3Nyma1bt6KrqwsVFRXnvMxR9XGj5NChQ/jTn/6EESNGpMwPrR4zCNbtcWvXrnX5+flu1apV7uDBg+6uu+5yQ4cOdfF4PNtFi6R77rnHxWIxt23bNvfuu+8mPx9++GFymbvvvtuVlpa6rVu3uj179rjKykpXWVmZxVJH39m/ynGOdejHrl27XP/+/d0jjzziDh
065J599ll3wQUXuN/97nfJZR599FE3dOhQt3HjRvfnP//ZTZs2zZWVlbmPPvooiyWPjurqavepT33Kbd682R0+fNg9//zz7qKLLnLf//73k8uwDlOdOHHCvfrqq+7VV191ANwvfvEL9+qrryZ/LeKnvr72ta+5K6+80u3cudO98sorbuzYse7222/P1iFlRbp6PH36tLvllltcSUmJ27dvX8r/mo6OjuQ2wqjHSDZMnHPuiSeecKWlpS4vL89dffXVbseOHdkuUmQB8PysXLkyucxHH33kvvOd77hhw4a5Cy64wH3961937777bvYK3QvIhgnr0J9Nmza5yy67zOXn57tx48a5Z555JmV+V1eXW7RokSssLHT5+fnuhhtucC0tLVkqbfQkEgk3f/58V1pa6gYOHOg+/elPux/+8IcpD3/WYaqXX37Z8xlYXV3tnPNXX//85z/d7bff7gYPHuwKCgrcHXfc4U6cOJGFo8medPV4+PDhT/xf8/LLLye3EUY95jh3VjpBIiIioiyKXIwJERERnb/YMCEiIqLIYMOEiIiIIoMNEyIiIooMNkyIiIgoMtgwISIioshgw4SIiIgigw0TIiIiigw2TIiIiCgy2DAhIiKiyGDDhIiIiCKDDRMiIiKKjP8AhoB1uxYGu3IAAAAASUVORK5CYII=",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126],\n",
" [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n",
" 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116]],\n",
" device='cuda:0')\n"
]
}
],
"source": [
"print(d3pm.q_mats[200])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1lklEQVR4nO3de3xU9Z3/8XcSkkmCTGKwSUgJmK2tJAhyq2TqpaghKU1dL+mFNsWsoq40tCbZBWR/ELmoESx3I5SKxD5KWnG3WAUKGYKAlHALRrlYtCs1tjiT3WIYAZkMyfz+6CNnGbkOJhm+w+v5eMzj4Zzzme/5fDIxvB/nzEki/H6/XwAAAAaJDHUDAAAAwSLAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACM0y3UDXSWtrY2HT58WD169FBERESo2wEAABfB7/fr008/VVpamiIjz32eJWwDzOHDh5Wenh7qNgAAwCX46KOP1Lt373PuD9sA06NHD0n/+ALY7fYOW9fn86mmpka5ubmKjo7usHUvZ1fazMwb3pg3vDGv+Twej9LT061/x88lbANM+2Uju93e4QEmPj5edrs9bL5ZLuRKm5l5wxvzhjfmDR8X+vgHH+IFAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAME63UDdgqhumrZe39fx/6rsj/eWZ/C47FgAAlzvOwAAAAOMEFWBaW1s1depUZWRkKC4uTl/5ylc0c+ZM+f1+q8bv96u8vFy9evVSXFyccnJy9P777wesc+TIERUWFsputysxMVFjx47VsWPHAmreeecd3XrrrYqNjVV6erpmz579BcYEAADhJKgAM2vWLC1evFjPPfec3n33Xc2aNUuzZ8/WokWLrJrZs2dr4cKFWrJkiXbs2KHu3bsrLy9PJ0+etGoKCwu1f/9+OZ1OrV69Wlu2bNEjjzxi7fd4PMrNzVXfvn1VX1+vZ599VtOmTdPSpUs7YGQAAGC6oD4Ds23bNt19993Kz//H5zGuvfZa/eY3v9HOnTsl/ePsy/z58zVlyhTdfffdkqRf/epXSklJ0auvvqrRo0fr3Xff1bp167Rr1y4NGzZMkrRo0SJ9+9vf1s9//nOlpaVpxYoVamlp0YsvvqiYmBj1799fDQ0Nmjt3bkDQAQAAV6agAsw3vvENLV26VO+9956+9rWv6e2339bWrVs1d+5cSdKhQ4fkcrmUk5NjvSYhIUHDhw9XXV2dRo8erbq6OiUmJlrhRZJycnIUGRmpHTt26N5771VdXZ1uu+02xcTEWDV5eXmaNWuWPvnkE1199dVn9Ob1euX1eq3nHo9HkuTz+eTz+YIZ87za17JF+i9Q2bE6coZLPXYoe+hKzBvemDe8Ma/5LnaWoALM448/Lo/Ho379+ikqKkqtra166qmnVFhYKElyuVySpJSUlIDXpaSkWPtcLpeSk5MDm+jWTUlJSQE1GRkZZ6zRvu9sAaaiokLTp08/Y3tNTY3i4+ODGfOizBzW1uFrns/atWu79Hhn43Q6Q91Cl2Le8Ma84Y15zXXixImLqgsqwKxcuVIrVqxQdXW1dVmnpKREaWlpKioquqRGO8rkyZNVVlZmPfd4PEpPT1dubq7sdnuHHcfn88npdGrq7kh527ruNup90/K67Fif1z7zyJEjFR0dHbI+ugrzhjfmDW/Ma772KygXElSAmTBhgh5//HGNHj1akjRgwAB9+OGHqqioUFFRkVJTUyVJbrdbvXr1sl7ndrs1aNAgSVJqaqqampoC1j116pSOHDlivT41NVVutzugpv15e83n2Ww22Wy2M7ZHR0d3ypvqbYvo0t8Dczl8Y3bW1/JyxbzhjXnDG/Oa62LnCOoupBMnTigyMvAlUVFRamv7x+WUjIwMpaamqra21t
rv8Xi0Y8cOORwOSZLD4VBzc7Pq6+utmo0bN6qtrU3Dhw+3arZs2RJwHczpdOr6668/6+UjAABwZQkqwNx111166qmntGbNGv3lL3/RqlWrNHfuXN17772SpIiICJWUlOjJJ5/Ua6+9pr179+r+++9XWlqa7rnnHklSZmamvvWtb+nhhx/Wzp079cc//lHjx4/X6NGjlZaWJkn60Y9+pJiYGI0dO1b79+/Xyy+/rAULFgRcIgIAAFeuoC4hLVq0SFOnTtVPfvITNTU1KS0tTf/6r/+q8vJyq2bixIk6fvy4HnnkETU3N+uWW27RunXrFBsba9WsWLFC48eP15133qnIyEgVFBRo4cKF1v6EhATV1NSouLhYQ4cO1TXXXKPy8nJuoQYAAJKCDDA9evTQ/PnzNX/+/HPWREREaMaMGZoxY8Y5a5KSklRdXX3eYw0cOFBvvvlmMO0BAIArBH8LCQAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwTlAB5tprr1VERMQZj+LiYknSyZMnVVxcrJ49e+qqq65SQUGB3G53wBqNjY3Kz89XfHy8kpOTNWHCBJ06dSqgZtOmTRoyZIhsNpuuu+46VVVVfbEpAQBAWAkqwOzatUsff/yx9XA6nZKk733ve5Kk0tJSvf7663rllVe0efNmHT58WPfdd5/1+tbWVuXn56ulpUXbtm3TSy+9pKqqKpWXl1s1hw4dUn5+vm6//XY1NDSopKREDz30kNavX98R8wIAgDDQLZjiL33pSwHPn3nmGX3lK1/RN7/5TR09elTLli1TdXW17rjjDknS8uXLlZmZqe3btys7O1s1NTU6cOCANmzYoJSUFA0aNEgzZ87UpEmTNG3aNMXExGjJkiXKyMjQnDlzJEmZmZnaunWr5s2bp7y8vA4aGwAAmCyoAHO6lpYW/frXv1ZZWZkiIiJUX18vn8+nnJwcq6Zfv37q06eP6urqlJ2drbq6Og0YMEApKSlWTV5ensaNG6f9+/dr8ODBqqurC1ijvaakpOS8/Xi9Xnm9Xuu5x+ORJPl8Pvl8vksd8wzta9ki/R22ZjDHDYX2Y4eyh67EvOGNecMb85rvYme55ADz6quvqrm5Wf/yL/8iSXK5XIqJiVFiYmJAXUpKilwul1Vzenhp39++73w1Ho9Hn332meLi4s7aT0VFhaZPn37G9pqaGsXHxwc934XMHNbW4Wuez9q1a7v0eGfTfsnwSsG84Y15wxvzmuvEiRMXVXfJAWbZsmUaNWqU0tLSLnWJDjV58mSVlZVZzz0ej9LT05Wbmyu73d5hx/H5fHI6nZq6O1LetogOW/dC9k0L3eWz9plHjhyp6OjokPXRVZg3vDFveGNe87VfQbmQSwowH374oTZs2KDf/e531rbU1FS1tLSoubk54CyM2+1WamqqVbNz586AtdrvUjq95vN3Lrndbtnt9nOefZEkm80mm812xvbo6OhOeVO9bRHytnZdgLkcvjE762t5uWLe8Ma84Y15zXWxc1zS74FZvny5kpOTlZ+fb20bOnSooqOjVVtba207ePCgGhsb5XA4JEkOh0N79+5VU1OTVeN0OmW325WVlWXVnL5Ge037GgAAAEEHmLa2Ni1fvlxFRUXq1u3/TuAkJCRo7NixKisr0xtvvKH6+no98MADcjgcys7OliTl5uYqKytLY8aM0dtvv63169drypQpKi4uts6ePProo/rggw80ceJE/elPf9Lzzz+vlStXqrS0tINGBgAApgv6EtKGDRvU2NioBx988Ix98+bNU2
RkpAoKCuT1epWXl6fnn3/e2h8VFaXVq1dr3Lhxcjgc6t69u4qKijRjxgyrJiMjQ2vWrFFpaakWLFig3r1764UXXuAWagAAYAk6wOTm5srvP/stxLGxsaqsrFRlZeU5X9+3b98L3lEzYsQIvfXWW8G2BgAArhD8LSQAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDhBB5i//e1v+vGPf6yePXsqLi5OAwYM0O7du639fr9f5eXl6tWrl+Li4pSTk6P3338/YI0jR46osLBQdrtdiYmJGjt2rI4dOxZQ88477+jWW29VbGys0tPTNXv27EscEQAAhJugAswnn3yim2++WdHR0frDH/6gAwcOaM6cObr66qutmtmzZ2vhwoVasmSJduzYoe7duysvL08nT560agoLC7V//345nU6tXr1aW7Zs0SOPPGLt93g8ys3NVd++fVVfX69nn31W06ZN09KlSztgZAAAYLpuwRTPmjVL6enpWr58ubUtIyPD+m+/36/58+drypQpuvvuuyVJv/rVr5SSkqJXX31Vo0eP1rvvvqt169Zp165dGjZsmCRp0aJF+va3v62f//znSktL04oVK9TS0qIXX3xRMTEx6t+/vxoaGjR37tyAoAMAAK5MQQWY1157TXl5efre976nzZs368tf/rJ+8pOf6OGHH5YkHTp0SC6XSzk5OdZrEhISNHz4cNXV1Wn06NGqq6tTYmKiFV4kKScnR5GRkdqxY4fuvfde1dXV6bbbblNMTIxVk5eXp1mzZumTTz4JOOPTzuv1yuv1Ws89Ho8kyefzyefzBTPmebWvZYv0d9iawRw3FNqPHcoeuhLzhjfmDW/Ma76LnSWoAPPBBx9o8eLFKisr03/8x39o165d+tnPfqaYmBgVFRXJ5XJJklJSUgJel5KSYu1zuVxKTk4ObKJbNyUlJQXUnH5m5/Q1XS7XWQNMRUWFpk+ffsb2mpoaxcfHBzPmRZk5rK3D1zyftWvXdunxzsbpdIa6hS7FvOGNecMb85rrxIkTF1UXVIBpa2vTsGHD9PTTT0uSBg8erH379mnJkiUqKioKvssONHnyZJWVlVnPPR6P0tPTlZubK7vd3mHH8fl8cjqdmro7Ut62iA5b90L2TcvrsmN9XvvMI0eOVHR0dMj66CrMG96YN7wxr/nar6BcSFABplevXsrKygrYlpmZqf/6r/+SJKWmpkqS3G63evXqZdW43W4NGjTIqmlqagpY49SpUzpy5Ij1+tTUVLnd7oCa9uftNZ9ns9lks9nO2B4dHd0pb6q3LULe1q4LMJfDN2ZnfS0vV8wb3pg3vDGvuS52jqDuQrr55pt18ODBgG3vvfee+vbtK+kfH+hNTU1VbW2ttd/j8WjHjh1yOBySJIfDoebmZtXX11s1GzduVFtbm4YPH27VbNmyJeA6mNPp1PXXX3/Wy0cAAODKElSAKS0t1fbt2/X000/rz3/+s6qrq7V06VIVFxdLkiIiIlRSUqInn3xSr732mvbu3av7779faWlpuueeeyT944zNt771LT388MPauXOn/vjHP2r8+PEaPXq00tLSJEk/+tGPFBMTo7Fjx2r//v16+eWXtWDBgoBLRAAA4MoV1CWkr3/961q1apUmT56sGTNmKCMjQ/Pnz1dhYaFVM3HiRB0/flyPPPKImpubdcstt2jdunWKjY21alasWKHx48frzjvvVGRkpAoKCrRw4UJrf0JCgmpqalRcXKyhQ4fqmmuuUXl5ObdQAwAASUEGGEn6zne+o+985zvn3B8REa
EZM2ZoxowZ56xJSkpSdXX1eY8zcOBAvfnmm8G2BwAArgD8LSQAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYJygAsy0adMUERER8OjXr5+1/+TJkyouLlbPnj111VVXqaCgQG63O2CNxsZG5efnKz4+XsnJyZowYYJOnToVULNp0yYNGTJENptN1113naqqqi59QgAAEHaCPgPTv39/ffzxx9Zj69at1r7S0lK9/vrreuWVV7R582YdPnxY9913n7W/tbVV+fn5amlp0bZt2/TSSy+pqqpK5eXlVs2hQ4eUn5+v22+/XQ0NDSopKdFDDz2k9evXf8FRAQBAuOgW9Au6dVNqauoZ248ePaply5apurpad9xxhyRp+fLlyszM1Pbt25Wdna2amhodOHBAGzZsUEpKigYNGqSZM2dq0qRJmjZtmmJiYrRkyRJlZGRozpw5kqTMzExt3bpV8+bNU15e3hccFwAAhIOgA8z777+vtLQ0xcbGyuFwqKKiQn369FF9fb18Pp9ycnKs2n79+qlPnz6qq6tTdna26urqNGDAAKWkpFg1eXl5GjdunPbv36/Bgwerrq4uYI32mpKSkvP25fV65fV6recej0eS5PP55PP5gh3znNrXskX6O2zNYI4bCu3HDmUPXYl5wxvzhjfmNd/FzhJUgBk+fLiqqqp0/fXX6+OPP9b06dN16623at++fXK5XIqJiVFiYmLAa1JSUuRyuSRJLpcrILy072/fd74aj8ejzz77THFxcWftraKiQtOnTz9je01NjeLj44MZ86LMHNbW4Wuez9q1a7v0eGfjdDpD3UKXYt7wxrzhjXnNdeLEiYuqCyrAjBo1yvrvgQMHavjw4erbt69Wrlx5zmDRVSZPnqyysjLrucfjUXp6unJzc2W32zvsOD6fT06nU1N3R8rbFtFh617Ivmmhu3zWPvPIkSMVHR0dsj66CvOGN+YNb8xrvvYrKBcS9CWk0yUmJuprX/ua/vznP2vkyJFqaWlRc3NzwFkYt9ttfWYmNTVVO3fuDFij/S6l02s+f+eS2+2W3W4/b0iy2Wyy2WxnbI+Oju6UN9XbFiFva9cFmMvhG7OzvpaXK+YNb8wb3pjXXBc7xxf6PTDHjh3Tf//3f6tXr14aOnSooqOjVVtba+0/ePCgGhsb5XA4JEkOh0N79+5VU1OTVeN0OmW325WVlWXVnL5Ge037GgAAAEEFmH//93/X5s2b9Ze//EXbtm3Tvffeq6ioKP3whz9UQkKCxo4dq7KyMr3xxhuqr6/XAw88IIfDoezsbElSbm6usrKyNGbMGL399ttav369pkyZouLiYuvsyaOPPqoPPvhAEydO1J/+9Cc9//zzWrlypUpLSzt+egAAYKSgLiH99a9/1Q9/+EP9/e9/15e+9CXdcsst2r59u770pS9JkubNm6fIyEgVFBTI6/UqLy9Pzz//vPX6qKgorV69WuPGjZPD4VD37t1VVFSkGTNmWDUZGRlas2aNSktLtWDBAvXu3VsvvPACt1ADAABLUAHmt7/97Xn3x8bGqrKyUpWVlees6du37wXvqBkxYoTeeuutYFoDAABXEP4WEgAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDA
AAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgnC8UYJ555hlFRESopKTE2nby5EkVFxerZ8+euuqqq1RQUCC32x3wusbGRuXn5ys+Pl7JycmaMGGCTp06FVCzadMmDRkyRDabTdddd52qqqq+SKsAACCMXHKA2bVrl37xi19o4MCBAdtLS0v1+uuv65VXXtHmzZt1+PBh3Xfffdb+1tZW5efnq6WlRdu2bdNLL72kqqoqlZeXWzWHDh1Sfn6+br/9djU0NKikpEQPPfSQ1q9ff6ntAgCAMHJJAebYsWMqLCzUL3/5S1199dXW9qNHj2rZsmWaO3eu7rjjDg0dOlTLly/Xtm3btH37dklSTU2NDhw4oF//+tcaNGiQRo0apZkzZ6qyslItLS2SpCVLligjI0Nz5sxRZmamxo8fr+9+97uaN29eB4wMAABM1+1SXlRcXKz8/Hzl5OToySeftLbX19fL5/MpJyfH2tavXz/16dNHdXV1ys7OVl1dnQYMGKCUlBSrJi8vT+PGjdP+/fs1ePBg1dXVBazRXnP6parP83q98nq91nOPxyNJ8vl88vl8lzLmWbWvZYv0d9iawRw3FNqPHcoeuhLzhjfmDW/Ma76LnSXoAPPb3/5We/bs0a5du87Y53K5FBMTo8TExIDtKSkpcrlcVs3p4aV9f/u+89V4PB599tlniouLO+PYFRUVmj59+hnba2pqFB8ff/EDXqSZw9o6fM3zWbt2bZce72ycTmeoW+hSzBvemDe8Ma+5Tpw4cVF1QQWYjz76SI899picTqdiY2MvqbHOMnnyZJWVlVnPPR6P0tPTlZubK7vd3mHH8fl8cjqdmro7Ut62iA5b90L2TcvrsmN9XvvMI0eOVHR0dMj66CrMG96YN7wxr/nar6BcSFABpr6+Xk1NTRoyZIi1rbW1VVu2bNFzzz2n9evXq6WlRc3NzQFnYdxut1JTUyVJqamp2rlzZ8C67XcpnV7z+TuX3G637Hb7Wc++SJLNZpPNZjtje3R0dKe8qd62CHlbuy7AXA7fmJ31tbxcMW94Y97wxrzmutg5gvoQ75133qm9e/eqoaHBegwbNkyFhYXWf0dHR6u2ttZ6zcGDB9XY2CiHwyFJcjgc2rt3r5qamqwap9Mpu92urKwsq+b0Ndpr2tcAAABXtqDOwPTo0UM33HBDwLbu3burZ8+e1vaxY8eqrKxMSUlJstvt+ulPfyqHw6Hs7GxJUm5urrKysjRmzBjNnj1bLpdLU6ZMUXFxsXUG5dFHH9Vzzz2niRMn6sEHH9TGjRu1cuVKrVmzpiNmBgAAhruku5DOZ968eYqMjFRBQYG8Xq/y8vL0/PPPW/ujoqK0evVqjRs3Tg6HQ927d1dRUZFmzJhh1WRkZGjNmjUqLS3VggUL1Lt3b73wwgvKywvd50AAAMDl4wsHmE2bNgU8j42NVWVlpSorK8/5mr59+17wrpoRI0borbfe+qLtAQCAMMTfQgIAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjBNUgFm8eLEGDhwou90uu90uh8OhP/zhD9b+kydPqri4WD179tRVV12lgoICud3ugDUaGxuVn5+v+Ph4JScna8KECTp16lRAzaZNmzRkyBDZbDZdd911qqqquvQJAQBA2AkqwPTu3VvPPPOM6uvrtXv3bt1xxx26++67tX//fklSaWmpXn/9db3yyivavHmzDh8+rPvuu896fWtrq/Lz89XS0qJt27bppZdeUlVVlcrLy62aQ4
cOKT8/X7fffrsaGhpUUlKihx56SOvXr++gkQEAgOm6BVN81113BTx/6qmntHjxYm3fvl29e/fWsmXLVF1drTvuuEOStHz5cmVmZmr79u3Kzs5WTU2NDhw4oA0bNiglJUWDBg3SzJkzNWnSJE2bNk0xMTFasmSJMjIyNGfOHElSZmamtm7dqnnz5ikvL6+DxgYAACYLKsCcrrW1Va+88oqOHz8uh8Oh+vp6+Xw+5eTkWDX9+vVTnz59VFdXp+zsbNXV1WnAgAFKSUmxavLy8jRu3Djt379fgwcPVl1dXcAa7TUlJSXn7cfr9crr9VrPPR6PJMnn88nn813qmGdoX8sW6e+wNYM5bii0HzuUPXQl5g1vzBvemNd8FztL0AFm7969cjgcOnnypK666iqtWrVKWVlZamhoUExMjBITEwPqU1JS5HK5JEkulysgvLTvb993vhqPx6PPPvtMcXFxZ+2roqJC06dPP2N7TU2N4uPjgx3zgmYOa+vwNc9n7dq1XXq8s3E6naFuoUsxb3hj3vDGvOY6ceLERdUFHWCuv/56NTQ06OjRo/rP//xPFRUVafPmzUE32NEmT56ssrIy67nH41F6erpyc3Nlt9s77Dg+n09Op1NTd0fK2xbRYeteyL5pobt81j7zyJEjFR0dHbI+ugrzhjfmDW/Ma772KygXEnSAiYmJ0XXXXSdJGjp0qHbt2qUFCxboBz/4gVpaWtTc3BxwFsbtdis1NVWSlJqaqp07dwas136X0uk1n79zye12y263n/PsiyTZbDbZbLYztkdHR3fKm+pti5C3tesCzOXwjdlZX8vLFfOGN+YNb8xrroud4wv/Hpi2tjZ5vV4NHTpU0dHRqq2ttfYdPHhQjY2NcjgckiSHw6G9e/eqqanJqnE6nbLb7crKyrJqTl+jvaZ9DQAAgKDOwEyePFmjRo1Snz599Omnn6q6ulqbNm3S+vXrlZCQoLFjx6qsrExJSUmy2+366U9/KofDoezsbElSbm6usrKyNGbMGM2ePVsul0tTpkxRcXGxdfbk0Ucf1XPPPaeJEyfqwQcf1MaNG7Vy5UqtWbOm46cHAABGCirANDU16f7779fHH3+shIQEDRw4UOvXr9fIkSMlSfPmzVNkZKQKCgrk9XqVl5en559/3np9VFSUVq9erXHjxsnhcKh79+4qKirSjBkzrJqMjAytWbNGpaWlWrBggXr37q0XXniBW6gBAIAlqACzbNmy8+6PjY1VZWWlKisrz1nTt2/fC95RM2LECL311lvBtAYAAK4g/C0kAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGCcoAJMRUWFvv71r6tHjx5KTk7WPffco4MHDwbUnDx5UsXFxerZs6euuuoqFRQUyO12B9Q0NjYqPz9f8fHxSk5O1oQJE3Tq1KmAmk2bNmnIkCGy2Wy67rrrVFVVdWkTAgCAsBNUgNm8ebOKi4u1fft2OZ1O+Xw+5ebm6vjx41ZNaWmpXn/9db3yyivavHmzDh8+rPvuu8/a39raqvz8fLW0tGjbtm166aWXVFVVpfLycqvm0KFDys/P1+23366GhgaVlJTooYce0vr16ztgZAAAYLpuwRSvW7cu4HlVVZWSk5NVX1+v2267TUePHtWyZctUXV2tO+64Q5K0fPlyZWZmavv27crOzlZNTY0OHDigDRs2KCUlRYMGDdLMmTM1adIkTZs2TTExMVqyZIkyMjI0Z84cSVJmZqa2bt2qefPmKS8vr4NGBwAApgoqwHze0aNHJUlJSU
mSpPr6evl8PuXk5Fg1/fr1U58+fVRXV6fs7GzV1dVpwIABSklJsWry8vI0btw47d+/X4MHD1ZdXV3AGu01JSUl5+zF6/XK6/Vazz0ejyTJ5/PJ5/N9kTEDtK9li/R32JrBHDcU2o8dyh66EvOGN+YNb8xrvoud5ZIDTFtbm0pKSnTzzTfrhhtukCS5XC7FxMQoMTExoDYlJUUul8uqOT28tO9v33e+Go/Ho88++0xxcXFn9FNRUaHp06efsb2mpkbx8fGXNuR5zBzW1uFrns/atWu79Hhn43Q6Q91Cl2Le8Ma84Y15zXXixImLqrvkAFNcXKx9+/Zp69atl7pEh5o8ebLKysqs5x6PR+np6crNzZXdbu+w4/h8PjmdTk3dHSlvW0SHrXsh+6aF7tJZ+8wjR45UdHR0yProKswb3pg3vDGv+dqvoFzIJQWY8ePHa/Xq1dqyZYt69+5tbU9NTVVLS4uam5sDzsK43W6lpqZaNTt37gxYr/0updNrPn/nktvtlt1uP+vZF0my2Wyy2WxnbI+Oju6UN9XbFiFva9cFmMvhG7OzvpaXK+YNb8wb3pjXXBc7R1B3Ifn9fo0fP16rVq3Sxo0blZGREbB/6NChio6OVm1trbXt4MGDamxslMPhkCQ5HA7t3btXTU1NVo3T6ZTdbldWVpZVc/oa7TXtawAAgCtbUGdgiouLVV1drd///vfq0aOH9ZmVhIQExcXFKSEhQWPHjlVZWZmSkpJkt9v105/+VA6HQ9nZ2ZKk3NxcZWVlacyYMZo9e7ZcLpemTJmi4uJi6wzKo48+queee04TJ07Ugw8+qI0bN2rlypVas2ZNB48PAABMFNQZmMWLF+vo0aMaMWKEevXqZT1efvllq2bevHn6zne+o4KCAt12221KTU3V7373O2t/VFSUVq9eraioKDkcDv34xz/W/fffrxkzZlg1GRkZWrNmjZxOp2688UbNmTNHL7zwArdQAwAASUGegfH7L3zrcGxsrCorK1VZWXnOmr59+17wrpoRI0borbfeCqY9AABwheBvIQEAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxgk6wGzZskV33XWX0tLSFBERoVdffTVgv9/vV3l5uXr16qW4uDjl5OTo/fffD6g5cuSICgsLZbfblZiYqLFjx+rYsWMBNe+8845uvfVWxcbGKj09XbNnzw5+OgAAEJaCDjDHjx/XjTfeqMrKyrPunz17thYuXKglS5Zox44d6t69u/Ly8nTy5EmrprCwUPv375fT6dTq1au1ZcsWPfLII9Z+j8ej3Nxc9e3bV/X19Xr22Wc1bdo0LV269BJGBAAA4aZbsC8YNWqURo0addZ9fr9f8+fP15QpU3T33XdLkn71q18pJSVFr776qkaPHq13331X69at065duzRs2DBJ0qJFi/Ttb39bP//5z5WWlqYVK1aopaVFL774omJiYtS/f381NDRo7ty5AUEHAABcmYIOMOdz6NAhuVwu5eTkWNsSEhI0fPhw1dXVafTo0aqrq1NiYqIVXiQpJydHkZGR2rFjh+69917V1dXptttuU0xMjFWTl5enWbNm6ZNPPtHVV199xrG9Xq+8Xq/13OPxSJJ8Pp98Pl+Hzdi+li3S32FrBnPcUGg/dih76ErMG96YN7wxr/kudpYODTAul0uSlJKSErA9JSXF2udyuZScnBzYRLduSkpKCqjJyMg4Y432fWcLMBUVFZo+ffoZ22tqahQfH3+JE53bzGFtHb7m+axdu7ZLj3c2Tqcz1C10KeYNb8
wb3pjXXCdOnLioug4NMKE0efJklZWVWc89Ho/S09OVm5sru93eYcfx+XxyOp2aujtS3raIDlv3QvZNy+uyY31e+8wjR45UdHR0yProKswb3pg3vDGv+dqvoFxIhwaY1NRUSZLb7VavXr2s7W63W4MGDbJqmpqaAl536tQpHTlyxHp9amqq3G53QE378/aaz7PZbLLZbGdsj46O7pQ31dsWIW9r1wWYy+Ebs7O+lpcr5g1vzBvemNdcFztHh/4emIyMDKWmpqq2ttba5vF4tGPHDjkcDkmSw+FQc3Oz6uvrrZqNGzeqra1Nw4cPt2q2bNkScB3M6XTq+uuvP+vlIwAAcGUJOsAcO3ZMDQ0NamhokPSPD+42NDSosbFRERERKikp0ZNPPqnXXntNe/fu1f3336+0tDTdc889kqTMzEx961vf0sMPP6ydO3fqj3/8o8aPH6/Ro0crLS1NkvSjH/1IMTExGjt2rPbv36+XX35ZCxYsCLhEBAAArlxBX0LavXu3br/9dut5e6goKipSVVWVJk6cqOPHj+uRRx5Rc3OzbrnlFq1bt06xsbHWa1asWKHx48frzjvvVGRkpAoKCrRw4UJrf0JCgmpqalRcXKyhQ4fqmmuuUXl5ObdQAwAASZcQYEaMGCG//9y3EEdERGjGjBmaMWPGOWuSkpJUXV193uMMHDhQb775ZrDtAQCAKwB/CwkAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDjdQt0AAJjg2sfXdOnx/vJMfpceLxQ642tqi/Jr9k3SDdPWy9saEbAv3L+mV9r3KAHGEF39jXm68/1ACEfhMG8ofrCE8ns0GOHw/nYGU96/L+JK+wc+3F3WAaayslLPPvusXC6XbrzxRi1atEg33XRTqNsCLnvB/KDmH/TLU0f9Y8v7GzpdEZhOf3+lK+v9vWw/A/Pyyy+rrKxMTzzxhPbs2aMbb7xReXl5ampqCnVrAAAgxC7bADN37lw9/PDDeuCBB5SVlaUlS5YoPj5eL774YqhbAwAAIXZZXkJqaWlRfX29Jk+ebG2LjIxUTk6O6urqzvoar9crr9drPT969Kgk6ciRI/L5fB3Wm8/n04kTJ9TNF6nWtivjdF23Nr9OnGi7YmZm3vDGvOGNebvO3//+905Z99NPP5Uk+f3+8xf6L0N/+9vf/JL827ZtC9g+YcIE/0033XTW1zzxxBN+STx48ODBgwePMHh89NFH580Kl+UZmEsxefJklZWVWc/b2tp05MgR9ezZUxERHZdKPR6P0tPT9dFHH8lut3fYupezK21m5g1vzBvemNd8fr9fn376qdLS0s5bd1kGmGuuuUZRUVFyu90B291ut1JTU8/6GpvNJpvNFrAtMTGxs1qU3W4Pm2+Wi3Wlzcy84Y15wxvzmi0hIeGCNZflh3hjYmI0dOhQ1dbWWtva2tpUW1srh8MRws4AAMDl4LI8AyNJZWVlKioq0rBhw3TTTTdp/vz5On78uB544IFQtwYAAELssg0wP/jBD/Q///M/Ki8vl8vl0qBBg7Ru3TqlpKSEtC+bzaYnnnjijMtV4exKm5l5wxvzhjfmvXJE+P0Xuk8JAADg8nJZfgYGAADgfAgwAADAOAQYAABgHAIMAAAwDgEmSJWVlbr22msVGxur4cOHa+fOnaFuqVNUVFTo61//unr06KHk5GTdc889OnjwYKjb6jLPPPOMIiIiVFJSEupWOs3f/vY3/fjHP1bPnj0VFxenAQMGaPfu3aFuq1O0trZq6tSpysjIUFxcnL7yla9o5syZF/5bKwbZsmWL7r
rrLqWlpSkiIkKvvvpqwH6/36/y8nL16tVLcXFxysnJ0fvvvx+aZjvA+eb1+XyaNGmSBgwYoO7duystLU3333+/Dh8+HLqGv6ALvb+ne/TRRxUREaH58+d3WX+hQIAJwssvv6yysjI98cQT2rNnj2688Ubl5eWpqakp1K11uM2bN6u4uFjbt2+X0+mUz+dTbm6ujh8/HurWOt2uXbv0i1/8QgMHDgx1K53mk08+0c0336zo6Gj94Q9/0IEDBzRnzhxdffXVoW6tU8yaNUuLFy/Wc889p3fffVezZs3S7NmztWjRolC31mGOHz+uG2+8UZWVlWfdP3v2bC1cuFBLlizRjh071L17d+Xl5enkyZNd3GnHON+8J06c0J49ezR16lTt2bNHv/vd73Tw4EH98z//cwg67RgXen/brVq1Stu3b7/gr+EPCx3xxxevFDfddJO/uLjYet7a2upPS0vzV1RUhLCrrtHU1OSX5N+8eXOoW+lUn376qf+rX/2q3+l0+r/5zW/6H3vssVC31CkmTZrkv+WWW0LdRpfJz8/3P/jggwHb7rvvPn9hYWGIOupckvyrVq2ynre1tflTU1P9zz77rLWtubnZb7PZ/L/5zW9C0GHH+vy8Z7Nz506/JP+HH37YNU11onPN+9e//tX/5S9/2b9v3z5/3759/fPmzevy3roSZ2AuUktLi+rr65WTk2Nti4yMVE5Ojurq6kLYWdc4evSoJCkpKSnEnXSu4uJi5efnB7zP4ei1117TsGHD9L3vfU/JyckaPHiwfvnLX4a6rU7zjW98Q7W1tXrvvfckSW+//ba2bt2qUaNGhbizrnHo0CG5XK6A7+uEhAQNHz78ivj5Jf3jZ1hERESn/o28UGpra9OYMWM0YcIE9e/fP9TtdInL9jfxXm7+93//V62trWf8JuCUlBT96U9/ClFXXaOtrU0lJSW6+eabdcMNN4S6nU7z29/+Vnv27NGuXbtC3Uqn++CDD7R48WKVlZXpP/7jP7Rr1y797Gc/U0xMjIqKikLdXod7/PHH5fF41K9fP0VFRam1tVVPPfWUCgsLQ91al3C5XJJ01p9f7fvC2cmTJzVp0iT98Ic/DKs/eHi6WbNmqVu3bvrZz34W6la6DAEGF1RcXKx9+/Zp69atoW6l03z00Ud67LHH5HQ6FRsbG+p2Ol1bW5uGDRump59+WpI0ePBg7du3T0uWLAnLALNy5UqtWLFC1dXV6t+/vxoaGlRSUqK0tLSwnBf/x+fz6fvf/778fr8WL14c6nY6RX19vRYsWKA9e/YoIiIi1O10GS4hXaRrrrlGUVFRcrvdAdvdbrdSU1ND1FXnGz9+vFavXq033nhDvXv3DnU7naa+vl5NTU0aMmSIunXrpm7dumnz5s1auHChunXrptbW1lC32KF69eqlrKysgG2ZmZlqbGwMUUeda8KECXr88cc1evRoDRgwQGPGjFFpaakqKipC3VqXaP8ZdaX9/GoPLx9++KGcTmfYnn1588031dTUpD59+lg/vz788EP927/9m6699tpQt9dpCDAXKSYmRkOHDlVtba21ra2tTbW1tXI4HCHsrHP4/X6NHz9eq1at0saNG5WRkRHqljrVnXfeqb1796qhocF6DBs2TIWFhWpoaFBUVFSoW+xQN9988xm3xb/33nvq27dviDrqXCdOnFBkZOCPu6ioKLW1tYWoo66VkZGh1NTUgJ9fHo9HO3bsCMufX9L/hZf3339fGzZsUM+ePUPdUqcZM2aM3nnnnYCfX2lpaZowYYLWr18f6vY6DZeQglBWVqaioiINGzZMN910k+bPn6/jx4/rgQceCHVrHa64uFjV1dX6/e9/rx49eljXyRMSEhQXFxfi7jpejx49zvh8T/fu3dWzZ8+w/NxPaWmpvvGNb+jpp5/W97//fe3cuVNLly7V0qVLQ91ap7jrrrv01FNPqU+fPurfv7/eeustzZ07Vw8++GCoW+swx44d05///Gfr+aFDh9
TQ0KCkpCT16dNHJSUlevLJJ/XVr35VGRkZmjp1qtLS0nTPPfeErukv4Hzz9urVS9/97ne1Z88erV69Wq2trdbPsKSkJMXExISq7Ut2off38wEtOjpaqampuv7667u61a4T6tugTLNo0SJ/nz59/DExMf6bbrrJv3379lC31CkknfWxfPnyULfWZcL5Nmq/3+9//fXX/TfccIPfZrP5+/Xr51+6dGmoW+o0Ho/H/9hjj/n79Onjj42N9f/TP/2T///9v//n93q9oW6tw7zxxhtn/X+2qKjI7/f/41bqqVOn+lNSUvw2m81/5513+g8ePBjapr+A88176NChc/4Me+ONN0Ld+iW50Pv7eVfCbdQRfn8Y/SpKAABwReAzMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAY5/8Dryo8NpaTnRMAAAAASUVORK5CYII=",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# THIS CHECKS IF Q SAMPLE REALLY SAMPLES from Q.\n",
"\n",
"x = torch.tensor([[0]])\n",
"collects = []\n",
"for idx in range(10000):\n",
" sample = d3pm.q_sample(x, torch.tensor([200], device = 'cuda:0'), torch.rand((*x.shape, d3pm.num_classses), device='cuda:0'))\n",
" collects.append(sample.item())\n",
"\n",
"# plot histogram\n",
"import pandas as pd\n",
"\n",
"df = pd.DataFrame(collects, columns=['sample'])\n",
"df['sample'].hist(bins=16)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"qmats = torch.arange(0, 1000).reshape(10, 10, 10)\n",
"image_as_pixels = torch.randint(0, 10, (1, 1, 1))\n",
"print(image_as_pixels)\n",
"qmats[image_as_pixels].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"qmats[1, 0, 1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cu122py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}