Repository: cloneofsimo/d3pm Branch: main Commit: 3ceb63725ee2 Files: 10 Total size: 215.3 KB Directory structure: gitextract_jcwmbccq/ ├── .gitignore ├── CITATION.cff ├── d3pm_runner.py ├── d3pm_runner_cifar10.py ├── dit.py ├── lm.py ├── lm_deepspeed.py ├── readme.md ├── run_multigpu.sh └── test.ipynb ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ presetup.sh setup.sh cu122py310 data test wandb run.sh *.pyc ckpt ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "Citations would be appreciated if you end up using this tool! I currently go by Simo Ryu" authors: - family-names: "Ryu" given-names: "Simo" orcid: "https://orcid.org/0009-0008-0017-2677" title: "Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch" version: 0.0.1 date-released: 2024-04 url: "https://github.com/cloneofsimo/d3pm" ================================================ FILE: d3pm_runner.py ================================================ import numpy as np import torch import torch.nn as nn from PIL import Image from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST from torchvision.utils import make_grid from tqdm import tqdm blk = lambda ic, oc: nn.Sequential( nn.Conv2d(ic, oc, 5, padding=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), nn.Conv2d(oc, oc, 5, padding=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), nn.Conv2d(oc, oc, 5, padding=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), ) blku = lambda ic, oc: nn.Sequential( nn.Conv2d(ic, oc, 5, padding=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), nn.Conv2d(oc, oc, 5, padding=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), nn.Conv2d(oc, oc, 5, padding=2), 
nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), nn.ConvTranspose2d(oc, oc, 2, stride=2), nn.GroupNorm(oc // 8, oc), nn.LeakyReLU(), ) class DummyX0Model(nn.Module): def __init__(self, n_channel: int, N: int = 16) -> None: super(DummyX0Model, self).__init__() self.down1 = blk(n_channel, 16) self.down2 = blk(16, 32) self.down3 = blk(32, 64) self.down4 = blk(64, 512) self.down5 = blk(512, 512) self.up1 = blku(512, 512) self.up2 = blku(512 + 512, 64) self.up3 = blku(64, 32) self.up4 = blku(32, 16) self.convlast = blk(16, 16) self.final = nn.Conv2d(16, N * n_channel, 1, bias=False) self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8) self.tr2 = nn.TransformerEncoderLayer(d_model=512, nhead=8) self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8) self.cond_embedding_1 = nn.Embedding(10, 16) self.cond_embedding_2 = nn.Embedding(10, 32) self.cond_embedding_3 = nn.Embedding(10, 64) self.cond_embedding_4 = nn.Embedding(10, 512) self.cond_embedding_5 = nn.Embedding(10, 512) self.cond_embedding_6 = nn.Embedding(10, 64) self.temb_1 = nn.Linear(32, 16) self.temb_2 = nn.Linear(32, 32) self.temb_3 = nn.Linear(32, 64) self.temb_4 = nn.Linear(32, 512) self.N = N def forward(self, x, t, cond) -> torch.Tensor: x = (2 * x.float() / self.N) - 1.0 t = t.float().reshape(-1, 1) / 1000 t_features = [torch.sin(t * 3.1415 * 2**i) for i in range(16)] + [ torch.cos(t * 3.1415 * 2**i) for i in range(16) ] tx = torch.cat(t_features, dim=1).to(x.device) t_emb_1 = self.temb_1(tx).unsqueeze(-1).unsqueeze(-1) t_emb_2 = self.temb_2(tx).unsqueeze(-1).unsqueeze(-1) t_emb_3 = self.temb_3(tx).unsqueeze(-1).unsqueeze(-1) t_emb_4 = self.temb_4(tx).unsqueeze(-1).unsqueeze(-1) cond_emb_1 = self.cond_embedding_1(cond).unsqueeze(-1).unsqueeze(-1) cond_emb_2 = self.cond_embedding_2(cond).unsqueeze(-1).unsqueeze(-1) cond_emb_3 = self.cond_embedding_3(cond).unsqueeze(-1).unsqueeze(-1) cond_emb_4 = self.cond_embedding_4(cond).unsqueeze(-1).unsqueeze(-1) cond_emb_5 = 
self.cond_embedding_5(cond).unsqueeze(-1).unsqueeze(-1) cond_emb_6 = self.cond_embedding_6(cond).unsqueeze(-1).unsqueeze(-1) x1 = self.down1(x) + t_emb_1 + cond_emb_1 x2 = self.down2(nn.functional.avg_pool2d(x1, 2)) + t_emb_2 + cond_emb_2 x3 = self.down3(nn.functional.avg_pool2d(x2, 2)) + t_emb_3 + cond_emb_3 x4 = self.down4(nn.functional.avg_pool2d(x3, 2)) + t_emb_4 + cond_emb_4 x5 = self.down5(nn.functional.avg_pool2d(x4, 2)) x5 = ( self.tr1(x5.reshape(x5.shape[0], x5.shape[1], -1).transpose(1, 2)) .transpose(1, 2) .reshape(x5.shape) ) y = self.up1(x5) + cond_emb_5 y = ( self.tr2(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)) .transpose(1, 2) .reshape(y.shape) ) y = self.up2(torch.cat([x4, y], dim=1)) + cond_emb_6 y = ( self.tr3(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)) .transpose(1, 2) .reshape(y.shape) ) y = self.up3(y) y = self.up4(y) y = self.convlast(y) y = self.final(y) # reshape to B, C, H, W, N y = ( y.reshape(y.shape[0], -1, self.N, *x.shape[2:]) .transpose(2, -1) .contiguous() ) return y class D3PM(nn.Module): def __init__( self, x0_model: nn.Module, n_T: int, num_classes: int = 10, forward_type="uniform", hybrid_loss_coeff=0.001, ) -> None: super(D3PM, self).__init__() self.x0_model = x0_model self.n_T = n_T self.hybrid_loss_coeff = hybrid_loss_coeff steps = torch.arange(n_T + 1, dtype=torch.float64) / n_T alpha_bar = torch.cos((steps + 0.008) / 1.008 * torch.pi / 2) self.beta_t = torch.minimum( 1 - alpha_bar[1:] / alpha_bar[:-1], torch.ones_like(alpha_bar[1:]) * 0.999 ) # self.beta_t = [1 / (self.n_T - t + 1) for t in range(1, self.n_T + 1)] self.eps = 1e-6 self.num_classses = num_classes q_onestep_mats = [] q_mats = [] # these are cumulative for beta in self.beta_t: if forward_type == "uniform": mat = torch.ones(num_classes, num_classes) * beta / num_classes mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes) q_onestep_mats.append(mat) else: raise NotImplementedError q_one_step_mats = torch.stack(q_onestep_mats, 
dim=0) q_one_step_transposed = q_one_step_mats.transpose( 1, 2 ) # this will be used for q_posterior_logits q_mat_t = q_onestep_mats[0] q_mats = [q_mat_t] for idx in range(1, self.n_T): q_mat_t = q_mat_t @ q_onestep_mats[idx] q_mats.append(q_mat_t) q_mats = torch.stack(q_mats, dim=0) self.logit_type = "logit" # register self.register_buffer("q_one_step_transposed", q_one_step_transposed) self.register_buffer("q_mats", q_mats) assert self.q_mats.shape == ( self.n_T, num_classes, num_classes, ), self.q_mats.shape def _at(self, a, t, x): # t is 1-d, x is integer value of 0 to num_classes - 1 bs = t.shape[0] t = t.reshape((bs, *[1] * (x.dim() - 1))) # out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m] return a[t - 1, x, :] def q_posterior_logits(self, x_0, x_t, t): # if t == 1, this means we return the L_0 loss, so directly try to x_0 logits. # otherwise, we return the L_{t-1} loss. # Also, we never have t == 0. # if x_0 is integer, we convert it to one-hot. if x_0.dtype == torch.int64 or x_0.dtype == torch.int32: x_0_logits = torch.log( torch.nn.functional.one_hot(x_0, self.num_classses) + self.eps ) else: x_0_logits = x_0.clone() assert x_0_logits.shape == x_t.shape + (self.num_classses,), print( f"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}" ) # Here, we caclulate equation (3) of the paper. Note that the x_0 Q_t x_t^T is a normalizing constant, so we don't deal with that. 
# fact1 is "guess of x_{t-1}" from x_t # fact2 is "guess of x_{t-1}" from x_0 fact1 = self._at(self.q_one_step_transposed, t, x_t) softmaxed = torch.softmax(x_0_logits, dim=-1) # bs, ..., num_classes qmats2 = self.q_mats[t - 2].to(dtype=softmaxed.dtype) # bs, num_classes, num_classes fact2 = torch.einsum("b...c,bcd->b...d", softmaxed, qmats2) out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps) t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim()))) bc = torch.where(t_broadcast == 1, x_0_logits, out) return bc def vb(self, dist1, dist2): # flatten dist1 and dist2 dist1 = dist1.flatten(start_dim=0, end_dim=-2) dist2 = dist2.flatten(start_dim=0, end_dim=-2) out = torch.softmax(dist1 + self.eps, dim=-1) * ( torch.log_softmax(dist1 + self.eps, dim=-1) - torch.log_softmax(dist2 + self.eps, dim=-1) ) return out.sum(dim=-1).mean() def q_sample(self, x_0, t, noise): # forward process, x_0 is the clean input. logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps) noise = torch.clip(noise, self.eps, 1.0) gumbel_noise = -torch.log(-torch.log(noise)) return torch.argmax(logits + gumbel_noise, dim=-1) def model_predict(self, x_0, t, cond): # this part exists because in general, manipulation of logits from model's logit # so they are in form of x_0's logit might be independent to model choice. # for example, you can convert 2 * N channel output of model output to logit via get_logits_from_logistic_pars # they introduce at appendix A.8. predicted_x0_logits = self.x0_model(x_0, t, cond) return predicted_x0_logits def forward(self, x: torch.Tensor, cond: torch.Tensor = None) -> torch.Tensor: """ Makes forward diffusion x_t from x_0, and tries to guess x_0 value from x_t using x0_model. 
x is one-hot of dim (bs, ...), with int values of 0 to num_classes - 1 """ t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device) x_t = self.q_sample( x, t, torch.rand((*x.shape, self.num_classses), device=x.device) ) # x_t is same shape as x assert x_t.shape == x.shape, print( f"x_t.shape: {x_t.shape}, x.shape: {x.shape}" ) # we use hybrid loss. predicted_x0_logits = self.model_predict(x_t, t, cond) # based on this, we first do vb loss. true_q_posterior_logits = self.q_posterior_logits(x, x_t, t) pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t) vb_loss = self.vb(true_q_posterior_logits, pred_q_posterior_logits) predicted_x0_logits = predicted_x0_logits.flatten(start_dim=0, end_dim=-2) x = x.flatten(start_dim=0, end_dim=-1) ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x) return self.hybrid_loss_coeff * vb_loss + ce_loss, { "vb_loss": vb_loss.detach().item(), "ce_loss": ce_loss.detach().item(), } def p_sample(self, x, t, cond, noise): predicted_x0_logits = self.model_predict(x, t, cond) pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t) noise = torch.clip(noise, self.eps, 1.0) not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim()))) gumbel_noise = -torch.log(-torch.log(noise)) sample = torch.argmax( pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1 ) return sample def sample(self, x, cond=None): for t in reversed(range(1, self.n_T)): t = torch.tensor([t] * x.shape[0], device=x.device) x = self.p_sample( x, t, cond, torch.rand((*x.shape, self.num_classses), device=x.device) ) return x def sample_with_image_sequence(self, x, cond=None, stride=10): steps = 0 images = [] for t in reversed(range(1, self.n_T)): t = torch.tensor([t] * x.shape[0], device=x.device) x = self.p_sample( x, t, cond, torch.rand((*x.shape, self.num_classses), device=x.device) ) steps += 1 if steps % stride == 0: images.append(x) # if last step is not divisible by stride, we add the last 
image. if steps % stride != 0: images.append(x) return images if __name__ == "__main__": N = 2 # number of classes for discretized state per pixel d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes=N, hybrid_loss_coeff=0.0).cuda() print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}") dataset = MNIST( "./data", train=True, download=True, transform=transforms.Compose( [ transforms.ToTensor(), transforms.Pad(2), ] ), ) dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32) optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=1e-3) d3pm.train() n_epoch = 400 device = "cuda" global_step = 0 for i in range(n_epoch): pbar = tqdm(dataloader) loss_ema = None for x, cond in pbar: optim.zero_grad() x = x.to(device) cond = cond.to(device) # discritize x to N bins x = (x * (N - 1)).round().long().clamp(0, N - 1) loss, info = d3pm(x, cond) loss.backward() norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.1) with torch.no_grad(): param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()]) if loss_ema is None: loss_ema = loss.item() else: loss_ema = 0.99 * loss_ema + 0.01 * loss.item() pbar.set_description( f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}" ) optim.step() global_step += 1 if global_step % 300 == 1: d3pm.eval() with torch.no_grad(): cond = torch.arange(0, 4).cuda() % 10 init_noise = torch.randint(0, N, (4, 1, 32, 32)).cuda() images = d3pm.sample_with_image_sequence( init_noise, cond, stride=40 ) # image sequences to gif gif = [] for image in images: x_as_image = make_grid(image.float() / (N - 1), nrow=2) img = x_as_image.permute(1, 2, 0).cpu().numpy() img = (img * 255).astype(np.uint8) gif.append(Image.fromarray(img)) gif[0].save( f"contents/sample_{global_step}.gif", save_all=True, append_images=gif[1:], duration=100, loop=0, ) last_img = gif[-1] 
last_img.save(f"contents/sample_{global_step}_last.png") d3pm.train() ================================================ FILE: d3pm_runner_cifar10.py ================================================ import numpy as np import torch from PIL import Image from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.utils import make_grid from tqdm import tqdm import wandb from d3pm_runner import D3PM from dit import DiT_Llama if __name__ == "__main__": wandb.init(project="d3pm_cifar10") N = 8 # number of classes for discretized state per pixel d3pm = D3PM( DiT_Llama(3, N, dim=1024), 1000, num_classes=N, hybrid_loss_coeff=0.0 ).cuda() print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}") dataset = CIFAR10( "./data", train=True, download=True, transform=transforms.Compose( [ transforms.RandomHorizontalFlip(), transforms.ToTensor(), ] ), ) dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=16) optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=2e-5) d3pm.train() n_epoch = 4000 device = "cuda" global_step = 0 for i in range(n_epoch): pbar = tqdm(dataloader) loss_ema = None for x, cond in pbar: optim.zero_grad() x = x.to(device) cond = cond.to(device) # discritize x to N bins x_cat = (x * (N - 1)).round().long().clamp(0, N - 1) loss, info = d3pm(x_cat, cond) loss.backward() norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 5.0) with torch.no_grad(): param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()]) if loss_ema is None: loss_ema = loss.item() else: loss_ema = 0.99 * loss_ema + 0.01 * loss.item() if global_step % 10 == 0: wandb.log( { "train_loss": loss, "train_grad_norm": norm, "train_param_norm": param_norm, } ) pbar.set_description( f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}" ) optim.step() global_step += 1 if 
global_step % 600 == 1: d3pm.eval() with torch.no_grad(): cond = torch.arange(0, 16).cuda() % 10 init_noise = torch.randint(0, N, (16, 3, 32, 32)).cuda() images = d3pm.sample_with_image_sequence( init_noise, cond, stride=40 ) # image sequences to gif gif = [] for image in images: x_from_dataloader = x_cat[:16].cpu() / (N - 1) this_image = image.float().cpu() / (N - 1) all_images = torch.cat([x_from_dataloader, this_image], dim=0) x_as_image = make_grid(all_images, nrow=4) img = x_as_image.permute(1, 2, 0).cpu().numpy() img = (img * 255).astype(np.uint8) gif.append(Image.fromarray(img)) gif[0].save( f"contents/sample_{global_step}.gif", save_all=True, append_images=gif[1:], duration=100, loop=0, ) last_img = gif[-1] last_img.save(f"contents/sample_{global_step}_last.png") # log images wandb.log( { "sample": wandb.Image(last_img), } ) d3pm.train() ================================================ FILE: dit.py ================================================ # Code heavilty based on https://github.com/Alpha-VLLM/LLaMA2-Accessory import math import torch import torch.nn as nn import torch.nn.functional as F def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size), nn.SiLU(), nn.Linear(hidden_size, hidden_size), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half) / half ).to(t.device) args = t[:, None] * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to( 
dtype=next(self.parameters()).dtype ) t_emb = self.mlp(t_freq) return t_emb class LabelEmbedder(nn.Module): def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = int(dropout_prob > 0) self.embedding_table = nn.Embedding( num_classes + use_cfg_embedding, hidden_size ) self.num_classes = num_classes self.dropout_prob = dropout_prob def token_drop(self, labels, force_drop_ids=None): if force_drop_ids is None: drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob drop_ids = drop_ids.cuda() drop_ids = drop_ids.to(labels.device) else: drop_ids = force_drop_ids == 1 labels = torch.where(drop_ids, self.num_classes, labels) return labels def forward(self, labels, train, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (train and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings class Attention(nn.Module): def __init__(self, dim, n_heads): super().__init__() self.n_heads = n_heads self.n_rep = 1 self.head_dim = dim // n_heads self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) self.wk = nn.Linear(dim, self.n_heads * self.head_dim, bias=False) self.wv = nn.Linear(dim, self.n_heads * self.head_dim, bias=False) self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim) self.k_norm = nn.LayerNorm(self.n_heads * self.head_dim) @staticmethod def reshape_for_broadcast(freqs_cis, x): ndim = x.ndim assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[1], x.shape[-1]) shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) @staticmethod def apply_rotary_emb(xq, xk, freqs_cis): xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) xq_out = 
torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out, xk_out def forward(self, x, freqs_cis): bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) dtype = xq.dtype xq = self.q_norm(xq) xk = self.k_norm(xk) xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) xk = xk.view(bsz, seqlen, self.n_heads, self.head_dim) xv = xv.view(bsz, seqlen, self.n_heads, self.head_dim) xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) xq, xk = xq.to(dtype), xk.to(dtype) output = F.scaled_dot_product_attention( xq.permute(0, 2, 1, 3), xk.permute(0, 2, 1, 3), xv.permute(0, 2, 1, 3), dropout_p=0.0, is_causal=False, ).permute(0, 2, 1, 3) output = output.flatten(-2) return self.wo(output) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None): super().__init__() hidden_dim = int(2 * hidden_dim / 3) if ffn_dim_multiplier: hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) def _forward_silu_gating(self, x1, x3): return F.silu(x1) * x3 def forward(self, x): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) class TransformerBlock(nn.Module): def __init__( self, layer_id, dim, n_heads, multiple_of, ffn_dim_multiplier, norm_eps, ): super().__init__() self.dim = dim self.head_dim = dim // n_heads self.attention = Attention(dim, n_heads) self.feed_forward = FeedForward( dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, ) self.layer_id = layer_id self.attention_norm = nn.LayerNorm(dim, eps=norm_eps) self.ffn_norm = nn.LayerNorm(dim, eps=norm_eps) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(min(dim, 1024), 6 * dim, bias=True), ) def forward(self, x, freqs_cis, 
adaln_input=None): if adaln_input is not None: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.adaLN_modulation(adaln_input).chunk(6, dim=1) ) x = x + gate_msa.unsqueeze(1) * self.attention( modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis ) x = x + gate_mlp.unsqueeze(1) * self.feed_forward( modulate(self.ffn_norm(x), shift_mlp, scale_mlp) ) else: x = x + self.attention(self.attention_norm(x), freqs_cis) x = x + self.feed_forward(self.ffn_norm(x)) return x class FinalLayer(nn.Module): def __init__(self, hidden_size, patch_size, out_channels): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear( hidden_size, patch_size * patch_size * out_channels, bias=True ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias=True), ) # # init zero nn.init.constant_(self.linear.weight, 0) nn.init.constant_(self.linear.bias, 0) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class DDiT_Llama(nn.Module): def __init__( self, N=256, dim=512, n_layers=5, n_heads=16, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-5, learn_gating=False, ): super().__init__() self.N = N self.learn_gating = learn_gating if self.learn_gating: self.out_channel = N * 2 else: self.out_channel = N self.embedder = nn.Embedding(N, dim) self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.layers = nn.ModuleList( [ TransformerBlock( layer_id, dim, n_heads, multiple_of, ffn_dim_multiplier, norm_eps, ) for layer_id in range(n_layers) ] ) self.final_layer = FinalLayer(dim, 1, self.out_channel) self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 4096) def forward(self, x, t, cond=None): self.freqs_cis = self.freqs_cis.to(x.device) x_onehot = torch.nn.functional.one_hot(x, self.N).to( x.device, dtype=next(self.parameters()).dtype 
) x = self.embedder(x) adaln_input = self.t_embedder(t) for layer in self.layers: x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input) x = self.final_layer(x, adaln_input) if self.learn_gating: x, gate = x.chunk(2, dim=-1) return x + x_onehot * (1 + gate).abs() else: return x + x_onehot class DiT_Llama(nn.Module): def __init__( self, in_channels=3, N=8, input_size=32, patch_size=2, dim=512, n_layers=5, n_heads=16, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-5, class_dropout_prob=0.1, num_classes=10, learn_sigma=True, ): super().__init__() self.N = N self.learn_sigma = learn_sigma self.in_channels = in_channels self.out_channels = N * in_channels * 2 self.input_size = input_size self.patch_size = patch_size self.init_conv_seq = nn.Sequential( nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1), nn.SiLU(), nn.GroupNorm(32, dim // 2), nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1), nn.SiLU(), nn.GroupNorm(32, dim // 2), ) self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True) nn.init.constant_(self.x_embedder.bias, 0) self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedder(num_classes, min(dim, 1024), class_dropout_prob) self.layers = nn.ModuleList( [ TransformerBlock( layer_id, dim, n_heads, multiple_of, ffn_dim_multiplier, norm_eps, ) for layer_id in range(n_layers) ] ) self.final_layer = FinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = DiT_Llama.precompute_freqs_cis(dim // n_heads, 4096) def unpatchify(self, x): c = self.out_channels p = self.patch_size h = w = int(x.shape[1] ** 0.5) x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum("nhwpqc->nchpwq", x) imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) return imgs def patchify(self, x): B, C, H, W = x.size() x = x.view( B, C, H // self.patch_size, self.patch_size, W // self.patch_size, self.patch_size, ) x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) return x 
def forward(self, x, t, y): self.freqs_cis = self.freqs_cis.to(x.device) x_onehot = torch.nn.functional.one_hot(x, self.N).float().to(x.device) x = (2 * x.float() / (self.N - 1)) - 1.0 x = self.init_conv_seq(x) x = self.patchify(x) x = self.x_embedder(x) t = self.t_embedder(t) # (N, D) y = self.y_embedder(y, self.training) # (N, D) adaln_input = t + y for layer in self.layers: x = layer(x, self.freqs_cis[: x.size(1)], adaln_input=adaln_input) x = self.final_layer(x, adaln_input) x = self.unpatchify(x) # (N, out_channels, H, W) x, gate = ( x.reshape(x.shape[0], -1, self.N * 2, *x.shape[2:]) .transpose(2, -1) .contiguous() ).chunk(2, dim=-1) return x + x_onehot * (1 + gate).abs() # x = (x.reshape(x.shape[0], -1, self.N, *x.shape[2:]) # .transpose(2, -1) # .contiguous() # ) # return x def forward_with_cfg(self, x, t, y, cfg_scale): half = x[: len(x) // 2] combined = torch.cat([half, half], dim=0) model_out = self.forward(combined, t, y) eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) return torch.cat([eps, rest], dim=1) @staticmethod def precompute_freqs_cis(dim, end, theta=10000.0): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end) freqs = torch.outer(t, freqs).float() freqs_cis = torch.polar(torch.ones_like(freqs), freqs) return freqs_cis def DiT_Llama_600M_patch2(**kwargs): return DiT_Llama(patch_size=2, dim=256, n_layers=16, n_heads=32, **kwargs) def DiT_Llama_3B_patch2(**kwargs): return DiT_Llama(patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs) if __name__ == "__main__": model = DiT_Llama_600M_patch2() model.eval() x = torch.randint(0, 8, (4, 3, 32, 32)) t = torch.randint(0, 1000, (4,)) y = torch.randint(0, 10, (4,)) out = model(x, t, y) print(out) # cuda ver model = model.cuda() x = x.cuda() t = t.cuda() y = 
y.cuda() out = model(x, t, y) print(out) ================================================ FILE: lm.py ================================================ import math import torch from datasets import load_dataset from torch.utils.data import DataLoader, Dataset from tqdm import tqdm from transformers import get_scheduler import wandb from d3pm_runner import D3PM from dit import DDiT_Llama class WikiTextDataset(Dataset): def __init__(self, tokenizer=None, type_path="train", max_length=512, debug=False): if debug: vernum = 2 else: vernum = 103 self.vernum = vernum self.dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split="train") self.tokenizer = tokenizer self.max_length = max_length def __len__(self): return ( int(len(self.dataset) * 0.1) if (self.vernum == 103) else len(self.dataset) ) def __getitem__(self, idx): text = self.dataset[idx]["text"] # logger.info(text) if self.tokenizer is not None: inputs = self.tokenizer( text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt", ) input_ids = inputs.input_ids.squeeze() else: # use byte encoding seq = list(text.encode("utf-8")) if len(seq) < self.max_length: seq += [0] * (self.max_length - len(seq)) else: seq = seq[: self.max_length] input_ids = torch.tensor(seq, dtype=torch.long) return {"input_ids": input_ids} if __name__ == "__main__": wandb.init(project="d3pm_wiki") N = 256 max_length = 256 num_train_epochs = 5 d3pm = D3PM( DDiT_Llama(N, dim=512, n_layers=6), 1000, num_classes=N, hybrid_loss_coeff=0.0 ).cuda() print(f"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}") dataset = WikiTextDataset(max_length=max_length, debug=False) dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=8) optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=2e-4) lr_scheduler = get_scheduler( name="linear", optimizer=optim, num_warmup_steps=100, num_training_steps=num_train_epochs * math.ceil(len(dataloader)), ) d3pm.train() device = 
"cuda" global_step = 0 for i in range(num_train_epochs): pbar = tqdm(dataloader) loss_ema = None for x in pbar: optim.zero_grad() x = x["input_ids"].to(device) # discritize x to N bins loss, info = d3pm(x) loss.backward() norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 5.0) with torch.no_grad(): param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()]) if loss_ema is None: loss_ema = loss.item() else: loss_ema = 0.99 * loss_ema + 0.01 * loss.item() if global_step % 10 == 0: wandb.log( { "train_loss": loss, "train_grad_norm": norm, "train_param_norm": param_norm, } ) pbar.set_description( f"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}" ) optim.step() lr_scheduler.step() global_step += 1 if global_step % 600 == 1: d3pm.eval() with torch.no_grad(): init_noise = torch.randint(0, N, (16, max_length)).cuda() outputs = d3pm.sample_with_image_sequence( init_noise, None, stride=40 ) gen_outputs = [] total = 0 # back to sentence, byte to utf-8 for _i in range(16): sent = outputs[-1][_i].cpu().tolist() correctly_parsed = True try: sent = b"".join([bytes([i]) for i in sent]).decode("utf-8") except: # if there is error, just unicodec correctly_parsed = False sent = "".join([chr(i) for i in sent]) sent = ( f"[{_i}] Sample Correctly parsed: " + str(correctly_parsed) + "\n" + sent ) total += 1 if correctly_parsed else 0 gen_outputs.append(sent) print(sent) # make a nice html to show the generated outputs html_formatted = "
".join(gen_outputs) # log text wandb.log( { "generated_text": wandb.Html(html_formatted), "correctly_parsed": total, } ) d3pm.train() if global_step % 3000 == 1: torch.save(d3pm.state_dict(), f"ckpt/d3pm_wiki_{global_step}.pth") print(f"Model saved at {global_step}") ================================================ FILE: lm_deepspeed.py ================================================ import math import os import random import click import deepspeed import numpy as np import torch from datasets import load_dataset from deepspeed import get_accelerator from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.utils import logger from torch.utils.data import DataLoader, Dataset, RandomSampler from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm from transformers import default_data_collator, get_scheduler import wandb from d3pm_runner import D3PM from dit import DDiT_Llama class WikiTextDataset(Dataset): def __init__( self, tokenizer=None, type_path="train", max_seq_length=512, debug=False ): if debug: self.dataset = load_dataset("wikitext", f"wikitext-2-raw-v1", split="test") else: self.dataset = load_dataset( "wikimedia/wikipedia", "20231101.en", split="train" ) self.tokenizer = tokenizer self.max_seq_length = max_seq_length def __len__(self): return len(self.dataset) def __getitem__(self, idx): text = self.dataset[idx]["text"] # logger.info(text) if self.tokenizer is not None: inputs = self.tokenizer( text, max_length=self.max_seq_length, padding="max_length", truncation=True, return_tensors="pt", ) input_ids = inputs.input_ids.squeeze() else: # use byte encoding seq = list(text.encode("utf-8")) if len(seq) < self.max_seq_length: seq += [0] * (self.max_seq_length - len(seq)) else: seq = seq[: self.max_seq_length] input_ids = torch.tensor(seq, dtype=torch.long) return {"input_ids": input_ids} def _z3_params_to_fetch(param_list): return [ p for p in param_list if hasattr(p, "ds_id") and p.ds_status == 
ZeroParamStatus.NOT_AVAILABLE ] def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0): zero_stage_3 = zero_stage == 3 os.makedirs(save_dir, exist_ok=True) WEIGHTS_NAME = "pytorch_model.bin" output_model_file = os.path.join(save_dir, WEIGHTS_NAME) model_to_save = model_ema.module if hasattr(model_ema, "module") else model_ema if not zero_stage_3: if global_rank == 0: torch.save(model_to_save.state_dict(), output_model_file) else: output_state_dict = {} for k, v in model_to_save.named_parameters(): if hasattr(v, "ds_id"): with deepspeed.zero.GatheredParameters( _z3_params_to_fetch([v]), enabled=zero_stage_3 ): v_p = v.data.cpu() else: v_p = v.cpu() if global_rank == 0 and "lora" not in k: output_state_dict[k] = v_p if global_rank == 0: torch.save(output_state_dict, output_model_file) del output_state_dict def set_seed(seed=42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) @click.command() @click.option("--local_rank", default=-1, help="Local rank") @click.option("--max_seq_length", default=256, help="Max sequence length") @click.option("--num_train_epochs", default=5, help="Number of training epochs") @click.option("--learning_rate", default=1e-4, help="Learning rate") @click.option("--offload", default=False, help="Offload") @click.option("--train_batch_size", default=1024, help="Train batch size") @click.option( "--per_device_train_batch_size", default=64, help="Per device train batch size" ) @click.option("--zero_stage", default=2, help="Zero stage") @click.option("--seed", default=42, help="Seed") @click.option("--run_name", default=None, help="Run name") def main( local_rank, max_seq_length=256, num_train_epochs=5, learning_rate=1e-4, offload=False, train_batch_size=512, per_device_train_batch_size=64, zero_stage=2, seed=42, run_name=None, ): # first, set the seed set_seed(seed) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_flash_sdp(False) if run_name is 
None: run_name = f"LR:{learning_rate}_max_seq_length:{max_seq_length}_num_train_epochs:{num_train_epochs}_offload:{offload}" if local_rank == -1: device = torch.device(get_accelerator().device_name()) else: get_accelerator().set_device(local_rank) device = torch.device(get_accelerator().device_name(), local_rank) # Initializes the distributed backend which will take care of sychronizing nodes/GPUs deepspeed.init_distributed() offload_device = "cpu" if offload else "none" ds_config = { "train_micro_batch_size_per_gpu": per_device_train_batch_size, "train_batch_size": train_batch_size, "zero_optimization": { "stage": zero_stage, "offload_param": {"device": offload_device}, "offload_optimizer": {"device": offload_device}, "stage3_param_persistence_threshold": 1e4, "stage3_max_live_parameters": 3e7, "stage3_prefetch_bucket_size": 3e7, "memory_efficient_linear": False, }, "bfloat16": {"enabled": True}, "gradient_clipping": 1.0, } torch.distributed.barrier() global_rank = torch.distributed.get_rank() ##### DEFINE model, dataset, sampler, dataloader, optim, schedular N = 256 with deepspeed.zero.Init(enabled=(zero_stage == 3)): d3pm = D3PM( DDiT_Llama(N, dim=768, n_layers=8), 1000, num_classes=N, hybrid_loss_coeff=0.0, ).cuda() total_params = sum(p.numel() for p in d3pm.parameters()) size_in_bytes = total_params * 4 size_in_gb = size_in_bytes / (1024**3) logger.info( f"Model Size: {size_in_bytes}, {size_in_gb} GB, Total Param Count: {total_params / 1e6} M" ) dataset = WikiTextDataset(max_seq_length=max_seq_length, debug=False) train_sampler = ( RandomSampler(dataset) if local_rank == -1 else DistributedSampler(dataset, seed=seed) ) dataloader = DataLoader( dataset, collate_fn=default_data_collator, sampler=train_sampler, batch_size=per_device_train_batch_size, ) optimizer = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=learning_rate) lr_scheduler = get_scheduler( name="linear", optimizer=optimizer, num_warmup_steps=100, num_training_steps=num_train_epochs * 
math.ceil(len(dataloader)), ) d3pm.train() model_engine, optimizer, _, lr_scheduler = deepspeed.initialize( model=d3pm, config=ds_config, lr_scheduler=lr_scheduler, optimizer=optimizer ) global_step = 0 ##### actual training loop if global_rank == 0: wandb.init( project="d3pm_wiki", name=run_name, config={ "N": N, "max_seq_length": max_seq_length, "num_train_epochs": num_train_epochs, "learning_rate": learning_rate, "offload": offload, "train_batch_size": train_batch_size, "per_device_train_batch_size": per_device_train_batch_size, "zero_stage": zero_stage, "seed": seed, }, ) for i in range(num_train_epochs): pbar = tqdm(dataloader) loss_ema = None for x in pbar: x = x["input_ids"].to(model_engine.device) # discritize x to N bins loss, info = model_engine(x) model_engine.backward(loss) model_engine.step() get_accelerator().empty_cache() norm = model_engine.get_global_grad_norm() if global_step % 10 == 0: if global_rank == 0: wandb.log({"train_loss": loss, "train_grad_norm": norm}) pbar.set_description( f"norm: {norm}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}" ) global_step += 1 if global_step % 600 == 1: d3pm.eval() with torch.no_grad(): init_noise = torch.randint(0, N, (16, max_seq_length)).cuda() outputs = d3pm.sample_with_image_sequence( init_noise, None, stride=40 ) gen_outputs = [] total = 0 # back to sentence, byte to utf-8 for _i in range(16): sent = outputs[-1][_i].cpu().tolist() correctly_parsed = True try: sent = b"".join([bytes([i]) for i in sent]).decode("utf-8") except: # if there is error, just unicodec correctly_parsed = False sent = "".join([chr(i) for i in sent]) sent = ( f"[{_i}] Sample Correctly parsed: " + str(correctly_parsed) + "\n" + sent ) total += 1 if correctly_parsed else 0 gen_outputs.append(sent) print(sent) model_engine.train() # make a nice html to show the generated outputs html_formatted = "
".join(gen_outputs) # log text if global_rank == 0: wandb.log( { "generated_text": wandb.Html(html_formatted), "correctly_parsed": total, } ) if global_step % 3000 == 1: save_zero_three_model( model_engine, global_rank, "./ckpt", zero_stage=zero_stage ) print(f"Model saved at {global_step}") if __name__ == "__main__": main() ================================================ FILE: readme.md ================================================

large large

# Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch

small small

**Special thanks to [fal.ai](https://fal.ai/) for the compute resources for this project.** This is a minimal (400 LOC) but fully faithful implementation of D3PM ([Structured Denoising Diffusion Models in Discrete State-Spaces](https://arxiv.org/abs/2107.03006)) in PyTorch. I have tried to keep the code as simple as possible, with plenty of comments and explanations (which are somewhat lacking in the original JAX implementation), so that it is easy to understand. As far as I know, this is the first faithful reimplementation of D3PM in PyTorch (please correct me if I am wrong). This implementation is heavily based on the [official implementation](https://github.com/google-research/google-research/tree/master/d3pm/images). Differences between this implementation and the official implementation: * This one supports conditional sampling, so, as you can see, generations are class-conditioned. * This one uses a different, simpler model architecture. * This one simplifies the official implementation considerably, down to about 400 LOC. * This one does not use the truncated logistic reparameterization, but you can add it if you wish. * Only the uniform transition matrix with an inverse-linear beta schedule is implemented, but you can change that with a couple of lines of code. ## Usage The following is a completely self-contained example: ```bash python d3pm_runner.py ``` The following uses `dit.py` on the CIFAR-10 dataset:
```bash python d3pm_runner_cifar10.py ``` ## Requirements Install torch, torchvision, pillow, and tqdm: ```bash pip install torch torchvision pillow tqdm ``` The language-model script `lm_deepspeed.py` additionally imports click, deepspeed, datasets, transformers, and wandb. ## Citation This implementation: ```bibtex @misc{d3pm_pytorch, author={Simo Ryu}, title={Minimal Implementation of a D3PM (Structured Denoising Diffusion Models in Discrete State-Spaces), in pytorch}, year={2024}, howpublished={\url{https://github.com/cloneofsimo/d3pm}} } ``` Original Paper: ```bibtex @article{austin2021structured, title={Structured denoising diffusion models in discrete state-spaces}, author={Austin, Jacob and Johnson, Daniel D and Ho, Jonathan and Tarlow, Daniel and Van Den Berg, Rianne}, journal={Advances in Neural Information Processing Systems}, volume={34}, pages={17981--17993}, year={2021} } ``` ================================================ FILE: run_multigpu.sh ================================================ export WORLD_SIZE=$(nvidia-smi -L | wc -l) deepspeed --num_gpus $WORLD_SIZE lm_deepspeed.py --learning_rate 1e-4 ================================================ FILE: test.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "\n", "def get_logits_from_logistic_pars(loc, log_scale, num_classes = 10):\n", " loc = loc.unsqueeze(-1)\n", " log_scale = log_scale.unsqueeze(-1)\n", "\n", " inv_scale = (-log_scale + 2.0).exp()\n", " \n", " bin_width = 2.0 / (num_classes - 1)\n", " bin_centers = torch.linspace(-1.0, 1.0, num_classes).to(loc.device)\n", " bin_centers = bin_centers.reshape((*loc.shape, num_classes))\n", " bin_centers = bin_centers - loc\n", " log_cdf_min = -torch.log1p((-inv_scale * (bin_centers - 0.5 * bin_width)).exp())\n", " log_cdf_plus = -torch.log1p((-inv_scale * (bin_centers + 0.5 * bin_width)).exp())\n", " logits = log_minus_exp(log_cdf_plus, log_cdf_min)\n", " return logits\n", "\n", "def log_minus_exp(a, b, epsilon=1.e-6):\n",
" return a + torch.log1p(-torch.exp(b - a) + epsilon)\n", "\n", "\n", "\n", "\n", "\n", "loc = torch.randn(1, 1, 1, 1).clip(-1, 1)\n", "log_scale = torch.randn(1, 1, 1, 1)\n", "\n", "num_classes = 10\n", "logits = get_logits_from_logistic_pars(loc, log_scale, num_classes)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.4609, -1.4102, -1.8315, -2.5864, -3.5119, -4.5085, -5.5319, -6.5647,\n", " -7.6002, -8.6343]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logits[0,0,0,0,:]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGiCAYAAADJO+2bAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABQWUlEQVR4nO3deVxU5f4H8M+wzLAIw76og4IbbriLoKWVaaalVlZmKmau3MzsVnhvZrahV+t2f+aWuWUZmblUZqallor7rqCiKIuACDIDCAPMPL8/zClSEXTgOQOf9+t1Xq/m8JzhezjCfDrnWVRCCAEiIiIiBbKTXQARERHR7TCoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYkkPKvn5+Zg8eTIaNWoEZ2dnREZGYv/+/bLLIiIiIgWQHlRefPFFbNmyBStXrsTx48fRp08f9O7dG+np6bJLIyIiIslUMhclLCoqgpubGzZs2ID+/ftb9nfq1An9+vXDe++9J6s0IiIiUgAHmd+8rKwMJpMJTk5O5fY7Oztj586dtzzGaDTCaDRaXpvNZuTm5sLb2xsqlapa6yUiIiLrEEIgPz8f9evXh51dBQ94hGQRERGiZ8+eIj09XZSVlYmVK1cKOzs70bx581u2nz59ugDAjRs3bty4casFW2pqaoU5QeqjHwA4d+4cXnjhBfz222+wt7dHx44d0bx5cxw8eBAJCQk3tf/7HRW9Xo+goCCkpqbC3d29JksnIiKiu2QwGKDT6ZCXlwetVnvbdlIf/QBAkyZNsGPHDhQWFsJgMCAwMBDPPPMMQkJCbtleo9FAo9HctN/d3Z1BhYiIyMbcqduG9FE/N7i6uiIwMBBXr17F5s2bMXDgQNklERERkWTS76hs3rwZQgi0aNECSUlJeO211xAaGopRo0bJLo2IiIgkk35HRa/XIzo6GqGhoRgxYgR69OiBzZs3w9HRUXZpREREJJn0zrT3ymAwQKvVQq/Xs48KERGRjajs57f0OypEREREt8OgQkRERIrFoEJERESKxaBCREREisWgQkRERIrFoEJERESKxaBCREREisWgQkRERIolNaiYTCZMmzYNwcHBcHZ2RpMmTfDuu+/CxuegIyIiIiuRutbPrFmzsGDBAqxYsQKtW7fGgQMHMGrUKGi1WkyaNElmaURERKQAUoPK7t27MXDgQPTv3x8A0LhxY3z11VfYt2+fzLKIiIhIIaQ++om
MjMQvv/yCM2fOAACOHj2KnTt3ol+/frc9xmg0wmAwlNuIiIiodpJ6RyUmJgYGgwGhoaGwt7eHyWTC+++/j2HDht32mNjYWMyYMaMGqyQiIiJZpN5RWb16Nb788kusWrUKhw4dwooVKzBnzhysWLHitsdMnToVer3esqWmptZgxURERFSTVELiEBudToeYmBhER0db9r333nv44osvkJiYWKn3qOwy0URERKQclf38lnpH5dq1a7CzK1+Cvb09zGazpIqIiIhISaT2UXnsscfw/vvvIygoCK1bt8bhw4fx0Ucf4YUXXpBZFhERESmE1Ec/+fn5mDZtGtatW4fLly+jfv36GDp0KN566y2o1epKvQcf/RAREdmeyn5+Sw0q1sCgQkREZHtsoo8KERERUUUYVIiIiEixGFSIiIhIsRhUiIiISLEYVIiIiEixGFSIiIhIsRhUiIiISLEYVIiIiEixGFSIiIhIsaQHlcaNG0OlUt20/XVFZSIiIqqbpC5KCAD79++HyWSyvD5x4gQefvhhDBkyRGJVREREpATSg4qvr2+51zNnzkSTJk3Qs2dPSRURERGRUkgPKn9VUlKCL774AlOmTIFKpbplG6PRCKPRaHltMBhqqjwiIiKqYdL7qPzV+vXrkZeXh6ioqNu2iY2NhVartWw6na7mCiQiIqIapRJCCNlF3NC3b1+o1Wp8//33t21zqzsqOp3ujstEExERkXIYDAZotdo7fn4r5tHPxYsXsXXrVqxdu7bCdhqNBhqNpoaqIiIiIpkU8+hn2bJl8PPzQ//+/WWXQkRERAqhiKBiNpuxbNkyjBw5Eg4OirnJQ0RERJIpIqhs3boVKSkpeOGFF2SXQkRERAqiiNsXffr0gYL69BIREZFCKOKOChEREdGtMKgQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYkkPKunp6Xj++efh7e0NZ2dntG3bFgcOHJBdFhERESmA1Jlpr169iu7du+OBBx7Apk2b4Ovri7Nnz8LT01NmWURERKQQUoPKrFmzoNPpsGzZMsu+4OBgiRURERGRkkh99PPdd9+hc+fOGDJkCPz8/NChQwcsXry4wmOMRiMMBkO5jYiIiGonqUHl/PnzWLBgAZo1a4bNmzdjwoQJmDRpElasWHHbY2JjY6HVai2bTqerwYqJiIioJqmExGWL1Wo1OnfujN27d1v2TZo0Cfv370d8fPwtjzEajTAajZbXBoMBOp0Oer0e7u7u1V4zERER3TuDwQCtVnvHz2+pd1QCAwPRqlWrcvtatmyJlJSU2x6j0Wjg7u5ebiMiIqLaSWpQ6d69O06fPl1u35kzZ9CoUSNJFREREZGSSA0qr7zyCvbs2YMPPvgASUlJWLVqFT799FNER0fLLIuIiIgUQmpQ6dKlC9atW4evvvoKbdq0wbvvvouPP/4Yw4YNk1kWERERKYTUzrTWUNnOOERERKQcNtGZloiIiKgiDCpERESkWAwqREREpFgMKkRERKRYDCpERESkWAwqREREpFgMKkRERKRYDCpERESkWAwqREREpFjSg8rbb78NlUpVbgsNDZVdFhERESmAg+wCAKB169bYunWr5bWDgyLKIiIiIskUkQgcHBwQEBBQqbZGoxFGo9Hy2mAwVFdZREREJJn0Rz8AcPbsWdSvXx8hISEYNmwYUlJSbts2NjYWWq3Wsul0uhqslIiIiGqS9NWTN23ahIKCArRo0QIZGRmYMWMG0tPTceLECbi5ud3U/lZ3VHQ6HVdPJiIisiGVXT1ZelD5u7y8PDRq1AgfffQRRo8efcf2lT1RIiIiUo7Kfn4r4tHPX3l4eKB58+ZISkqSXQoRERFJprigUlBQgHPnziEwMFB2KURERCSZ9KDyz3/+Ezt27MCFCxewe/duDB48GPb29hg6dKjs0oiIiEgy6cOT09LSMHToUOTk5MDX1xc9evTAnj174OvrK7s0IiI
ikkx6UImLi5NdAhERESmU9Ec/RERERLfDoEJERESKxaBCREREisWgQkRERIrFoEJERESKxaBCREREisWgQkRERIrFoEJERESKpaigMnPmTKhUKkyePFl2KURERKQAigkq+/fvx6JFixAWFia7FCIiIlIIRQSVgoICDBs2DIsXL4anp6fscoiIiEghFBFUoqOj0b9/f/Tu3fuObY1GIwwGQ7mNiIiIaidFLEp46NAh7N+/v1LtY2NjMWPGjGquioiIiJRA6h2V1NRUvPzyy/jyyy/h5ORUqWOmTp0KvV5v2VJTU6u5SiIiIpJFJYQQsr75+vXrMXjwYNjb21v2mUwmqFQq2NnZwWg0lvvarRgMBmi1Wuj1eri7u1d3yURERGQFlf38lvro56GHHsLx48fL7Rs1ahRCQ0Pxxhtv3DGkEBERUe0mNai4ubmhTZs25fa5urrC29v7pv1ERERU9yhi1A8RERHRrUgf9fN327dvl10CERERKQTvqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWJJDyoLFixAWFgY3N3d4e7ujoiICGzatEl2WURERKQA0oNKw4YNMXPmTBw8eBAHDhzAgw8+iIEDB+LkyZOySyMiIiLJVEIIIbuIv/Py8sLs2bMxevTom75mNBphNBotrw0GA3Q6HfR6Pdzd3WuyTCIiIrpLBoMBWq32jp/f0u+o/JXJZEJcXBwKCwsRERFxyzaxsbHQarWWTafT1XCVREREVFMUcUfl+PHjiIiIQHFxMerVq4dVq1bh0UcfvWVb3lEhIiKyfZW9o+JQgzXdVosWLXDkyBHo9XqsWbMGI0eOxI4dO9CqVaub2mo0Gmg0GglVEhERUU1TxB2Vv+vduzeaNGmCRYsW3bFtZRMZERERKYdN9lG5wWw2l3u8Q0RERHWT9Ec/U6dORb9+/RAUFIT8/HysWrUK27dvx+bNm2WXRkRERJJJDyqXL1/GiBEjkJGRAa1Wi7CwMGzevBkPP/yw7NKIiIhIMulBZcmSJbJLICIiIoVSZB8VIiIiIoBBhYiIiBSMQYWIiIgUi0GFiIiIFItBhYiIiBSLQYWIiIgUi0GFiIiIFItBhYiIiBRLelCJjY1Fly5d4ObmBj8/PwwaNAinT5+WXRYREREpgPSgsmPHDkRHR2PPnj3YsmULSktL0adPHxQWFsoujYiIiCRTCSGE7CL+Kjs7G35+ftixYwfuv//+O7av7DLRREREpByV/fyWvtbP3+n1egCAl5fXLb9uNBphNBotrw0GQ43URURERDVP+qOfvzKbzZg8eTK6d++ONm3a3LJNbGwstFqtZdPpdDVcJREREdUURT36mTBhAjZt2oSdO3eiYcOGt2xzqzsqOp2Oj36IiIhsiM09+vnHP/6BH374Ab/99tttQwoAaDQaaDSaGqyMiIiIZJEeVIQQeOmll7Bu3Tps374dwcHBsksiIiIihZAeVKKjo7Fq1Sps2LABbm5uyMzMBABotVo4OztLro6IiIhkkt5HRaVS3XL/smXLEBUVdcfjOTyZiIjI9thMHxUF9eUlIiIihVHU8GQiIiKiv2JQISIiIsViUCEiIiLFYlAhIiIixWJQISIiIsViUCEiIiLFYlAhIiIixWJQISIiIsViUCEiIiLFkh5UfvvtNzz22GOoX78+VCoV1q9fL7skIiIiUgjpQaWwsBDt2rXDvHnzZJdCRERECiN9rZ9+/fqhX79+lW5vNBphNBotrw0GQ3WURURERAog/Y5KVcXGxkKr1Vo2nU4nuyQiIiKqJjYXVKZOnQq9Xm/ZUlNTZZdERERE1UT6o5+q0mg00Gg0sssgIiKiGmBzd1SIiIio7mBQISIiIsWS/uinoKAASUlJltfJyck4cuQIvLy8EBQUJLEyIiIikk16UDlw4AAeeOABy+spU6YAAEaOHInly5dLqoqIiIiUQHpQ6dW
rF4QQsssgIiIiBWIfFSIiIlIsBhUiIiJSLAYVIiIiUiwGFSIiIlIsBhUiIiJSLAYVIiIiUiwGFSIiIlIsBhUiIiJSLEUElXnz5qFx48ZwcnJCeHg49u3bJ7skIiIiUgDpQeXrr7/GlClTMH36dBw6dAjt2rVD3759cfnyZdmlERERkWTSg8pHH32EMWPGYNSoUWjVqhUWLlwIFxcXLF26VHZpREREJJnUoFJSUoKDBw+id+/eln12dnbo3bs34uPjb3mM0WiEwWAotxEREVHtJDWoXLlyBSaTCf7+/uX2+/v7IzMz85bHxMbGQqvVWjadTlcTpRIREZEE0h/9VNXUqVOh1+stW2pqquySiIiIqJo4yPzmPj4+sLe3R1ZWVrn9WVlZCAgIuOUxGo0GGo2mJsojIiIiyaTeUVGr1ejUqRN++eUXyz6z2YxffvkFEREREisjIiIiJZB6RwUApkyZgpEjR6Jz587o2rUrPv74YxQWFmLUqFGySyMiIiLJpAeVZ555BtnZ2XjrrbeQmZmJ9u3b46effrqpgy0RERHVPSohhJBdxL0wGAzQarXQ6/Vwd3eXXQ4RERFVQmU/v21u1A8RERHVHQwqREREpFgMKkRERKRYDCpERESkWAwqREREpFgMKkRERKRY0udRISICgJIyM/RFpZbN/LeZE+ztVNA6O1o2R3v+fxZRXcCgQkQ1QgiBLIMRJy/pkXylECm5165vOdeQaSjGtRJTld7PVW2PAK0Tgrxc0MjbFUFeLgj2cUXrBu7wc3OqprMgoprGoEJE1aK41IT9F3Jx8OJVHE/T41i6Htn5xgqPUakAdydHuDs7wNGu/B2TEpMZhqJSGIrLAACFJSacyy7EuexCANnl2ga4O6FtQy3CGmjRqZEnOjX2hMbB3qrnR0Q1Q2pQef/997Fx40YcOXIEarUaeXl5MsshontgNgucvGTA70nZ2JV0BfsvXEVJmblcG3s7FZr51UMT33oI8na5fjfEywX1PZzh4eIINydH2NupKvw+JrOAoagUeUWlSL9ahIu5hZY7M0mXC5CUXYBMQzEyTxVjy6nrK7M7Odqha7A3ejT1xn3NfBEa4AaVquLvQ0TKIHUK/enTp8PDwwNpaWlYsmTJXQUVTqFPJI/ZLHA49So2HsvEphMZyNAXl/t6oNYJ3UK80a6hFm0beqBVoDuc1dV7Z6PQWIZTGQYcS9PjaGoe9pzPweW/3cnReTnj0TaBeLRtIMIaahlaiCSo7Oe3Itb6Wb58OSZPnlypoGI0GmE0/vlHx2AwQKfTMagQ1aAzWfn4en8qNh7LQKbhz3DiqrZHRBMf3NfMBz2a+SDEx1V6CBBC4OzlAvx+9gp2ns1G/PkcFJf+eaenoaczBoTVx9OdGyLEt57ESonqlsoGFZvroxIbG4sZM2bILoOozik0lmHjsQzE7U/BoZQ8y/56Ggf0bumHR9sG4v7mvnByVFZfEJVKheb+bmju74bRPYJxraQM209nY+PxDPyacBlpV4uwcMc5LNxxDuHBXni2qw792gQq7jyI6ireUSGiCl3MKcTSncn49lA6CozXO7La26nwUKgfnurUUJHhpLKKSkzYdvoy1hxMw/bTl2H+46+hu5MDnumiQ1T3YDTwcJZbJFEtJe2OSkxMDGbNmlVhm4SEBISGht7V+2s0Gmg0mrs6logq7+DFXCz+LRmbT2Xixv/ONPZ2wTNdgvBkpwa1Ygiws9oej7a93lflUl4R1hxMw9f7U5GeV4TFvydj6a4L6N82EC/eF4ywhh6yyyWqk6x+RyU7Oxs5OTkVtgkJCYFarba8rsodlb9jZ1oi6xFCYMeZbPzfL2fLPd7p1cIXo3sEo0dTH+l9Tqqb2Syw/cxlfPZ7Mnaf+/NvWbcQL0x6qBkiQrxr/c+AqCZIu6Pi6+sLX19fa78tEVWjGwHl461ncSQ1DwCgtrfD4A4NMPq+YDT3d5NbYA2ys1PhwVB/PBjqjxPpeizZmYzvj17
CnvO52HN+L7oGe2Fy72aIbOIju1SiOkFqZ9qUlBTk5uYiJSUFJpMJR44cAQA0bdoU9eqx9z1RTdh59grm/HzaElCcHO0wvFsjjLk/pFY83rkXbRpo8d9n2uP1R1pg4fZz+GpfKvYl5+K5xXsRHuyFf/ZtgS6NvWSXSVSrSe1MGxUVhRUrVty0f9u2bejVq1el3oOPfojuTkKGAbGbEvHbmeuzut4IKGPvbwJfN/YDu5UMfREWbD+HuH2pKDFdH+Lcp5U/3ugXiiYc2kxUJTY1j8q9YFAhqppMfTE+/Pk01hxKgxCAo70Kw8IbIfqBpgwolZShL8LcX5Pw9f5UmMwC9nYqDO2qw8sPNefPkKiSGFSIqJziUhMW/3Ye87YnWSY86982EK8/0gKNvF0lV2ebki7nY+amRGxNuAzg+oR3L/duhqjIYKgduLozUUUYVIjIYuupLLzzwymk5F4DAHRp7Il/PdoSHYI8JVdWO+w5n4MPfkzAsTQ9AKCJrytmPN4GPZqxwy3R7TCoEBEuXCnEOz+cwq+J1/+PP8DdCf/u3xIDwgI5xNbKzGaBNYfSMGtTInIKSwAA/doE4M0BrThpHNEtMKgQ1WGlJjM+/e08/vfLWZSUmeFor8LoHiF46cGmcNXY3MoZNkVfVIr/bjmDz+MvwCwAF7U9Xu3TAlGRje+4MjRRXcKgQlRHHUq5iqnfHsfprHwAQI+mPpgxsDVHpdSwhAwD3tpwAvsvXAUAhDXU4oPBbdGmgVZyZUTKwKBCVMcUGMvwn58SsXLPRQgBeLmq8daAVhjYvj4f80hiNgvE7U9F7KYE5BeXwd5OhRd7BGNy7+ZwVtvm+khE1sKgQlSH7Eq6gtfXHEN6XhEA4MmODfHv/i3h5aq+w5FUEy4bijHj+1PYeDwDABDs44o5Q8LQqREni6O6i0GFqA4oMJYh9scEfLk3BQDQ0NMZM58I42gThdp6Kgv/Xn8cWQYjVCpgdPdg/LNvC5tdfZroXlT281vaQP8LFy5g9OjRCA4OhrOzM5o0aYLp06ejpKREVklENmX3uSt45OPfLCFleLdG2Dz5foYUBevdyh8/T+6JJzs2hBDAZzuT8ej/fsfBi1dll0akWNK6/ycmJsJsNmPRokVo2rQpTpw4gTFjxqCwsBBz5syRVRaR4hWXmjB782ks2ZkM4PpdlP88GYbIpgwotkDr4ogPn26H/mEBiPn2OM5fKcSQhbsR/UBTTHqoGRztOVEc0V8p6tHP7NmzsWDBApw/f77Sx/DRD9Ulpy4ZMPnrwziTVQAAeC48CP96tCXqccixTdJfK8Xb35/EusPpAIC2fyyC2NSPI7So9lP8o59b0ev18PKquHOZ0WiEwWAotxHVdiazwMId5zBw3k6cySqATz01lkZ1xgeD2zKk2DCtiyP++0x7zHuuI7TOjjierseAub/j8/gLUND/QxJJpZigkpSUhLlz52LcuHEVtouNjYVWq7VsOp2uhiokkiNTX4xhn+3BzE2JKDUJPNzKH5sn348HQ/1ll0ZW0j8sEJsn34/7mvmguNSMtzacxOgVB5BTYJRdGpF0Vn/0ExMTg1mzZlXYJiEhAaGhoZbX6enp6NmzJ3r16oXPPvuswmONRiOMxj9/eQ0GA3Q6HR/9UK205VQWXltzFHnXSuGitsdbA1rhmS46zotSS5nNAiviLyB2UyJKyszwddPgv0+3ZwdpqpWkDU/Ozs5GTk5OhW1CQkKgVl+f3+HSpUvo1asXunXrhuXLl8POrmo3edhHhWqj4lITYn9MwIr4iwCANg3c8X/PdkAIZ5etExIyDJj01WGcvVwAlQoYd38TvNqnOTvaUq1iE/OopKen44EHHkCnTp3wxRdfwN6+6nMJMKhQbZN0uQD/WHUIiZnXp8Afc9/1uTY0Dpxroy4pKjHhvY2nLMPP2zXU4pPnOkLn5SK5MiLrUHxQSU9PR69evdCoUSOsWLGiXEgJCAio9PswqFBtsvZQGt5cfwL
XSkzwqafGnCHt0KuFn+yySKKfTmTgjW+PQ19UCjcnB8x+KgyPtAmUXRbRPVN8UFm+fDlGjRp1y69VpSQGFaoNikpMeGvDCXxzMA0AENnEGx8/2x5+bk6SKyMlSM8rwkurDuFQSh4AYGREI/yrf0veZSObpvigYi0MKmTrki7nY+KXh3Am63p/hMkPNcc/HmwKezt2mKU/lZrMmPPzaSzacX2eqTYN3DHvuY5o5O0quTKiu2OT86gQ1TUbjqTjsbm7cCarAL5uGnz5Yjhe7t2MIYVu4mhvh6n9WmJpVGd4uDjiRLoBA+buxM8nM2WXRlStGFSIJDCWmTBt/Qm8HHcERaUmRDbxxo+T7kNkEw5DpYo9GOqPHyfdh45BHsgvLsPYlQcR+2MCykxm2aURVQsGFaIalnb1Gp5eGI+Ve64PPX7pwaZYOTocvm4ayZWRrajv4Yy4sRF4oXswAGDRb+fx3Gd7cdlQLLkyIutjUCGqQTvOZGPA3J04mqaH1tkRy6K64NU+Lfioh6pM7WCHtx5rhfnDOqKexgH7knPx6P/txN7zFc9jRWRrGFSIaoDZLDD3l7OIWrYPeddK0a6hFhsn9cADoRx6TPfm0baB+O4f3dHC3w1XCox47rO9+Oz381wriGoNBhWiaqYvKsWYzw/gwy1nIMT1FY9Xj49AQ09O3EXWEeJbD+uiIzGwfX2YzALvbUzAS18dRqGxTHZpRPeMy64SVaPETAPGrTyIiznXoHaww3uD2uDpzlxIk6zPRe2Aj59pjw46D7y3MQE/HMvAmax8LHy+E5deIJvGOypE1eT7o5cweN5uXMy5hoaezlg7IZIhhaqVSqVCVPdgxI3tBj83Dc5kFWDgJ7vwS0KW7NKI7prUoPL4448jKCgITk5OCAwMxPDhw3Hp0iWZJRHdszKTGR/8eP3We1GpCfc188H3/+iBNg20skujOqJzYy/8MKkHujT2RL6xDKNXHMDHW8/AbGa/FbI9UoPKAw88gNWrV+P06dP49ttvce7cOTz11FMySyK6J7mFJRi5bB8+/e367KETejXB8lFd4emqllwZ1TV+bk748sVuGBHRCADw8dazGLvyIAzFpZIrI6oaRU2h/91332HQoEEwGo1wdHSs1DGcQp+U4uQlPcZ+fhDpeUVwUdtj9lPt0D+Mi8eRfN8cSMW/159ASZkZIT6u+HREZzT1Y78VksvmptDPzc3Fl19+icjIyApDitFohMFgKLcRyfb90Ut4csFupOcVoZG3C9ZN7M6QQooxpLMOa8ZHIFDrhPNXCjF4HvutkO2QHlTeeOMNuLq6wtvbGykpKdiwYUOF7WNjY6HVai2bTsfOiSSPySww66dEvPTVYRSXmtGzuS++i+6BFgFusksjKiesoQe+f6kHujb2Qr6xDC9+fgCf/HqW862Q4ln90U9MTAxmzZpVYZuEhASEhoYCAK5cuYLc3FxcvHgRM2bMgFarxQ8//ACV6tYzdRqNRhiNRstrg8EAnU7HRz9U4/RFpXg57jC2n84GAIzrGYLX+4ZylllStJIyM9794ZRlCYdH2wZg9lPt4KrhbBVUsyr76MfqQSU7Oxs5ORVP4RwSEgK1+ubOhWlpadDpdNi9ezciIiIq9f3YR4VkOJddgDErDuD8lUI4Odph1pNhGNi+geyyiCotbl8Kpm04gVKTQGiAGxaP6AydFychpJpT2c9vq0doX19f+Pr63tWxZvP11T//eseESGm2nb6MSasOI99YhgYezlg0vBOHHpPNebZrEJr518O4lYeQmJmPgfN2Yf6wjugW4i27NKJypI362bt3L/bv348ePXrA09MT586dw7Rp05CVlYWTJ09Co6ncSrK8o0I1RQiBT387j5k/JUIIoGtjL8x/viN86nHVY7JdGfoijP38II6n6+Fgp8L0x1tjeLdGssuiOkDxo35cXFywdu1aPPTQQ2jRogVGjx6NsLAw7Nixo9IhhaimFJea8MrXRxC76XpIGdo1CF+
8GM6QQjYvUOuMb8ZH4PF29VFmFpi2/gT+te44SsrMsksjAqCweVTuBu+oUHXLMhRj7OcHcDRND3s7Fd5+rBWe79both2+iWyREAILd5zHfzZfD+PhwV5Y8HwneHGyQqomir+jQmQLjqXl4fFPduJomh4eLo5YOborhkc0ZkihWkelUmFCryb4bERn1NM4YG9yLgbO24nTmfmyS6M6jkGF6Da+O3oJQxbGI8tgRDO/etgQ3R2RTXxkl0VUrR5q6Y+1EyMR5OWC1NwiPDF/F7ac4uRwJA+DCtHfmM0CH/58GpO+OgxjmRkPhfph7cRINPJ2lV0aUY1o7u+GDdHd0S3EC4UlJoxdeQDztydxcjiSgkGF6C8KjWWY8OVBzP01CcD1Sdw+HdEZbk6VW3uKqLbwdFVj5ehwPN8tCEIA//npNF5dfRTFpSbZpVEdw6BC9IdLeUUYsjAem09mQW1vhw+HtMPUfi050yzVWY72dnhvUFu8O7A17O1UWHs4HUMX70F2Pue6oprDoEIE4FDKVTz+yS6cyjDAp54aX40Nx5OdGsoui0gRhkc0xopRXeHu5IDDKXkY+MlOnLykl10W1REMKlTnrT+cjmc/3YMrBUaEBrhhfXR3dGrkJbssIkXp0cwH66O7I8THFZf0xXhqQTw2n8yUXRbVAQwqVGeZzQJzNp/G5K+PoKTMjIdb+ePbCZFo6Mn1TohuJcS3HtZN7I77mvmgqNSE8V8cxILt59jJlqoVgwrVSddKyhC96hA+2Xa90+yEXk2w6PlOXEGW6A60Lo5YFtUFIyIaQQhg1k+JePWbozCWsZMtVQ9FBBWj0Yj27dtDpVLhyJEjssuhWi5TX4ynF8Vj04lMONqrMGdIO7zxSCjs2GmWqFIc7O3wzsA2eOdGJ9tD6Ri2eC9yCtjJlqxPEUHl9ddfR/369WWXQXXA8TQ9Bs7biRPpBni5qrFqTDc8xU6zRHdlRERjLIvqAjcnBxy4eBUD5+3CmSzOZEvWJT2obNq0CT///DPmzJkjuxSq5TYdz8CQRbstM82un9gdXRqz0yzRvbi/uS/WTYxEI28XpF0twhPzd2P76cuyy6JaRGpQycrKwpgxY7By5Uq4uFSuA6PRaITBYCi3EVVECIF525Iw4ctDKC41o2dzX3w7MRJB3uw0S2QNTf3csH5id3QN9kKBsQwvLN+P5buS2cmWrEJaUBFCICoqCuPHj0fnzp0rfVxsbCy0Wq1l0+l01Vgl2TpjmQmvrj6K2ZtPAwCiIhtjycjOcOdMs0RW5emqxhejw/F054YwC+Dt709h2oYTKDWZZZdGNs7qQSUmJgYqlarCLTExEXPnzkV+fj6mTp1apfefOnUq9Hq9ZUtNTbX2KVAtkVNgxLDFe7H2cDrs7VR4d1AbvP14azjYS3/iSVQrqR3sMOvJMEztFwqVCvhiTwpeWL4f+qJS2aWRDVMJK9+by87ORk5OToVtQkJC8PTTT+P777+HSvXnSAuTyQR7e3sMGzYMK1asqNT3MxgM0Gq10Ov1cHd3v6faqfY4m5WPF1bsR2puEdycHDB/WEfc18xXdllEdcbPJzMx+esjuFZiQlO/elg6sgsft1I5lf38tnpQqayUlJRy/UsuXbqEvn37Ys2aNQgPD0fDhpUbicGgQn/3+9lsTPzyEPKLyxDk5YKlUZ3R1M9NdllEdc6JdD1eXHEAmYZieLmqsWh4J3ZgJwvFB5W/u3DhAoKDg3H48GG0b9++0scxqNBffbHnIqZ/dxIms0CXxp5YNLwzvFzVsssiqrOyDMV4ccUBHE/XQ21vh5lPtsUTHTklAFX+85sP66lWMJkFZnx/Em+uPwGTWeCJDg3wxYvhDClEkvm7O+Hrcd3wSOsAlJjMmLL6KD78+TTMZkX8PzLZAMXcUblbvKNCBcYyvPzVYfySeH3uhn/2aY7oB5qW6/9ERHKZzQKzfz6NBdvPAQAGhAVizpB2cHK0l1wZyVLZz28ubEI27VJeEV5
Yvh+JmfnQONjhw6fbYUAYZzkmUho7OxXeeCQUwT6u+Nfa4/jhWAbS84rw6fDO8HXTyC6PFIyPfshmHUvLw8B5u5CYmQ+fehrEje3GkEKkcE931mHl6HBonR1xOCUPgzjtPt0BgwrZpJ9OZODpRfHIzjeihb8b1kdHokOQp+yyiKgSIpp4Y93ESDT2dkF6XhGenL8bO85kyy6LFIpBhWyKEAILd5zD+C/+nA5/zYQINPTk/AxEtiTEtx7WTeyO8GAv5P8x7f7KPRdll0UKxKBCNqPUZEbMt8cxc1MiAGBERCMsGdkZbpwOn8gmebqqsXJ0OJ7s2BAms8C09SfwzvenYOKIIPoLdqYlm6C/VoqJqw5iV1IO7FTAWwNaIap7sOyyiOgeqR3sMGdIGEJ8XTF782ks3ZWMlNxC/O/ZDnDV8COKeEeFbEBKzjU8sWAXdiXlwFVtj89GdmZIIapFVCoVoh9oik+e6wCNgx22JlzGkIXxyNAXyS6NFIBBhRTt4MVcDJq/C+eyCxGodcI34yPxYKi/7LKIqBoMCKuPr8Z2g089NU5lGDBo3i6cSNfLLoskkxpUGjdufNPKyjNnzpRZEinIhiPpGLp4L3ILS9C2gRYborujVX1O6kdUm3UM8sS6id3RzK8esgxGDFkYjy2nsmSXRRJJv6PyzjvvICMjw7K99NJLsksiyYQQ+L9fzuLluCMoKTOjTyt/fD2uG/zcnWSXRkQ1QOflgjUTInFfMx8UlZowduUBfPb7edj4ROp0l6QHFTc3NwQEBFg2V1dX2SWRRMYyE15dfRQfbTkDABh7fwgWPt8JLmp2qiOqS7TOjlga1QVDuwZBCOC9jQmYtuEEykxm2aVRDZO61k/jxo1RXFyM0tJSBAUF4bnnnsMrr7wCB4fbfygZjUYYjUbLa4PBAJ1Ox7V+aoGrhSUY98VB7EvOhb2dCu8MbI1h4Y1kl0VEEgkhsPj384jdlAghgJ7NffHJcx04LUEtYBOrJ0+aNAlxcXHYtm0bxo0bhw8++ACvv/56hcfExsZCq9VaNp1OV0PVUnVKvlKIJxbsxr7kXLhpHLAsqgtDChFBpVJh7P1NsGBYJzg52mHHmWwMWRiP9DyOCKorrH5HJSYmBrNmzaqwTUJCAkJDQ2/av3TpUowbNw4FBQXQaG69SBXvqNQ++5JzMXblAeRdK0UDD2csjeqCFgFusssiIoU5lpaH0SsOIDvfCF83DZaM7Iywhh6yy6K7VNk7KlYPKtnZ2cjJyamwTUhICNRq9U37T548iTZt2iAxMREtWrSo1Per7ImSMq07nIY31hxHicmMdg21WDyyM/zc2GmWiG4tPa8Io/9YMd3J0Q7/e7YD+rYOkF0W3YXKfn5bvYeir68vfH197+rYI0eOwM7ODn5+flauipRGCIGPt57F/345CwDo1yYAHz3dHs5qe8mVEZGSNfBwxjfjI/CPVYex40w2xn9xEP/q1xIv3hcMlUoluzyqBtKGUsTHx2Pv3r144IEH4Obmhvj4eLzyyit4/vnn4enJVXBrM2OZCTHfHse6w+kAgHE9Q/BG31DY2fGPDBHdmZuTI5aM7Iy3vz+JL/ak4P0fE5CcU4h3Hm8NB3vpg1nJyqQFFY1Gg7i4OLz99tswGo0IDg7GK6+8gilTpsgqiWrA1cISjFt5EPsuXB/Z896gNhjaNUh2WURkYxzs7fDuwDZo7O2K939MwKq9KUi7WoR5HBFU60gdnmwN7KNiO5KvFOKF5fuRfKUQbhoHzH++I+5rdnePCYmIbvj5ZCZejjuColITQgPcsCSqCxp4OMsui+7AJoYnU92xLzkXg+fvQvKVQjTwcMa3EyMZUojIKvq0DsDqcRHwddMgMTMfg+btwrG0PNllkZUwqFC1W3c4Dc9/thd510rRTueB9dHd0dyfw4+JyHraNtRifXR3hAa4ITvfiKcXxWPzyUzZZZEVMKhQtbk+sucMXvn6KEpMZvRrE4C4Md3g63brOXK
IiO7FjRFBPZv7orjUjPFfHOQaQbUAgwpVC2OZCVNWH8XHW68PPx53fwjmPdeRw4+JqFrdGBH0fDeuEVRbcKU3srq8ayUYu/LPNXs4soeIatLfRwR9sScFqblFXCPIRvGOClnVhSuFGDz/zzV7lo/qwpBCRDVOpVLhxfuur77+1zWCLnGNIJvDoEJWs/9C+ZE9ayZwZA8RydX3FiOCjqfpZZdFVcCgQlax4Ug6hi3ei6vXShHWUIt10ZFcWJCIFCGs4fXRhi383XD5jxFBW05lyS6LKolBhe6JEAJzfzmLl+OOoMRkRt/W/vh6bAQXFiQiRWng4YxvJkTgvmY+KCo1YezKA1i2K1l2WVQJ0oPKxo0bER4eDmdnZ3h6emLQoEGyS6JKKikz45/fHMOHW84AAMbcF4z5wzpxZA8RKZK7kyOWRl3vNycEMOP7U5jOEUGKJ3XUz7fffosxY8bggw8+wIMPPoiysjKcOHFCZklUSfprpRj3xQHsOX99ZM+Mx1vj+W6NZJdFRFQhR3s7fDC4DYJ9XPDBj4lYEX8RqVeLMHdoB7hqOBBWiaSt9VNWVobGjRtjxowZGD16dKWPMxqNMBqNltcGgwE6nY5r/dSglJxriFq+D+ezC+GqtscnwzrigRZ+sssiIqqSTcczMPnrIzCWmdEq0B1Lo7ogQMvH1jVF8Wv9HDp0COnp6bCzs0OHDh0QGBiIfv363fGOSmxsLLRarWXT6XQ1VDEBwMGLVzF4/i6czy5EoNYJayZEMqQQkU3q1zYQcWO7waeeGqcyDBg0bxdOXTLILov+RlpQOX/+PADg7bffxptvvokffvgBnp6e6NWrF3Jzc2973NSpU6HX6y1bampqTZVc5208loGhi/cgp7AEbRq4Y310d7QM5F0sIrJdHYI8sW5idzTzq4dMQzGGLNyNbYmXZZdFf2H1oBITEwOVSlXhlpiYCLP5euelf//733jyySfRqVMnLFu2DCqVCt98881t31+j0cDd3b3cRtVLCIH525MQveoQSsrM6N3SD6vHRcDfnbdIicj26bxcsGZCJLo39UZhiQmjV+zHyvgLssuiP1i959Crr76KqKioCtuEhIQgIyMDANCqVSvLfo1Gg5CQEKSkpFi7LLpLpSYzpq0/gbj91+9cRUU2xrQBrWBvp5JcGRGR9WidHbEsqiveXH8cqw+kYdqGk7iYcw1TH23Jv3eSWT2o+Pr6wtf3zrORdurUCRqNBqdPn0aPHj0AAKWlpbhw4QIaNeLoESUwFJci+stD+P3sFdipgLcGtEJU92DZZRERVQu1gx1mPRmGRt6umL35ND7bmYyU3Gv4+Nn2cFFzRJAs0vqouLu7Y/z48Zg+fTp+/vlnnD59GhMmTAAADBkyRFZZ9Ie0q9fw1ILd+P3sFbio7bF4RGeGFCKq9VQqFaIfaIr/G9oBagc7/HwqC89+ugeX84tll1ZnSY2Is2fPhoODA4YPH46ioiKEh4fj119/haenp8yy6rxjaXkYveIAsvON8HPTYGlUF7RpoJVdFhFRjXm8XX3U1zphzOcHcCxNj8HzdmPZqC5o7s+lQWqatHlUrKWy47CpcjafzMTLcYdRXGpGaIAblkZ1QX0PZ9llERFJceFKIUYt34/kK4Vw0zhgwfOd0KOZj+yyagXFz6NCyiKEwGe/n8f4Lw6iuNSMns198c34CIYUIqrTGvu4Yu2ESHRt7IV8Yxmilu1D3D4O+KhJDCqEMpMZb204ifc2JkAIYFh4EJaM7Aw3J0fZpRERSefpqsbKF7tiUPv6KDMLxKw9jlk/JcJstukHEjaDQaWOKzCWYcznB7Byz0WoVMC/H22J9wa1gYM9/2kQEd2gcbDHf59pj0kPNQMALNh+Di/FHUZxqUlyZbUfP43qsEx9MZ5eGI9tp7OhcbDD/Oc6Ysz9IVCpOGcAEdHfqVQqTHm4OT4c0g6O9ipsPJaB5xbvQU6B8c4H011jUKmjTl36Y12LDAN86qnx9bg
I9GsbKLssIiLFe7JTQ3z+QjjcnRxwKCUPg+fvxrnsAtll1VoMKnXQttOXMWThbmQaitHUrx7WTeyO9joP2WUREdmMiCbeWDuxO3RezkjJvYYn5u/GnvM5ssuqlRhU6piVey5i9PL9KCwxIbKJN76dEAmdl4vssoiIbM6N/9HrEOQBfVEphi/Zi7WH0mSXVeswqNQRJrPAez+cwrT1J2AWwJBODbF8VFdonTmyh4jobvnU0+CrMd3Qv20gSk0CU1YfxX+3nIGNT1GmKNKCyvbt22+7uvL+/ftllVUrFZWYMPHLg/hsZzIA4LW+LfCfp8KgdmBOJSK6V06O9pg7tAPG92wCAPjfL2fx6uqjMJZxRJA1SJuZtqSkBLm5ueX2TZs2Db/88gvOnTtX6ZEnnJm2YpfzizFmxQEcTdNDbW+H2UPCMLB9A9llERHVSl/tS8Gb60/AZBYID/bCouGd4OGill2WIil+Zlq1Wo2AgADL5u3tjQ0bNmDUqFEcHmslZ7LyMXjebhxN08PTxRFfjglnSCEiqkZDuwZhWVQX1NM4YG9yLp5YsBsXcwpll2XTFHPv/7vvvkNOTg5GjRpVYTuj0QiDwVBuo5vtSrqCJxfsRnpeEYJ9XLF2Ynd0aewluywiolrv/ua+WDMhAvW1TjifXYjB83fj4MWrssuyWYoJKkuWLEHfvn3RsGHDCtvFxsZCq9VaNp1OV0MV2o7V+1Mxcuk+5BeXoWtjL6ydEIlgH1fZZRER1RmhAe5YH90dbRtokVtYgqGL92DjsQzZZdkkqweVmJiY23aSvbElJiaWOyYtLQ2bN2/G6NGj7/j+U6dOhV6vt2ypqanWPgWbZTYLzN6ciNe/PYYys8DA9vWx8sWu8HTl81Eioprm5+6Er8d1Q++W/igpMyN61SEs2H6OI4KqyOqdabOzs5GTU/GkNyEhIVCr//zwfPfddzF37lykp6fD0bFqw2XZmfa64lITXltzDN8fvQQAmPRgU7zycHP29yEiksxkFnhv4yks23UBADC0qw7vDGwDxzq+plplP78drP2NfX194evrW+n2QggsW7YMI0aMqHJIoetyC0sw9vMDOHDxKhzsVIh9oi2GdOYjMSIiJbC3U2H6Y63RyMsF7/xwCl/tS0Xa1SLMG9YR7lyl/o6kx7lff/0VycnJePHFF2WXYpPOZxdg8PxdOHDxKtycHPD5C10ZUoiIFCiqezA+Hd4Zzo72+P3sFQxZEI/0vCLZZSme9KCyZMkSREZGIjQ0VHYpNmefZejbNTT0dMbaCZGIbOojuywiIrqN3q388c34CPi5aXA6Kx+D5u3C8TS97LIUTdqEb9ZSV/uobDiSjte+OYYSkxntdB74bERn+LppZJdFRESVcCmvCC8s34/EzHw4O9rj/4Z2wMOt/GWXVaMUP+Eb3R0hBOb+chYvxx1BicmMvq39ETemG0MKEZENqe/hjG/GR+D+5r4oKjVh7MoDWLYrWXZZisSgYkNKysx4bc0xfLjlDABgzH3BWDCsE5zV9pIrIyKiqnJzcsSSkZ0xtGsQhABmfH8Kb393EiazTT/osDoGFRuhLypF1LJ9WHMwDXYq4N1BbfDv/q1gZ8fhx0REtsrR3g4fDG6DmH7X+2ku330B41YewLWSMsmVKQeDig1Izb2GJxfsxu5zOXBV22PJyC4Y3q2R7LKIiMgKVCoVxvdsgnnPdYTawQ5bEy7j6UXxuGwoll2aIjCoKNyR1DwMnr8LSZcL4O+uwerxEXgg1E92WUREZGX9wwLx1Zhu8HZV40S6AYPm7UJiJtezY1BRsJ9OZOLZT+NxpaAELQOvrxvRur5WdllERFRNOjXyxLqJ3RHi64pL+mI8tSAev53Jll2WVAwqCiSEwOLfzmPClwdRXGrGAy188c34CARqnWWXRkRE1SzI2wVrJ0QiPNgLBcYyjFq+H1/tS5FdljQMKgpTZjJj2oYTeP/HBAgBPN8tCItHdEY9jdVXOyAiIoXycFF
j5ehwPNGhAUxmgalrj2PmpkSY6+CIIAYVBSkwluHFzw/giz0pUKmAN/u3xLsD28Chji9cRURUF6kd7PDh0+0wuXczAMDCHefw0leHUVxqklxZzZL6CXjmzBkMHDgQPj4+cHd3R48ePbBt2zaZJUmTqS/GkIXx2H46G06OdlgwrBNevC+Eqx8TEdVhKpUKk3s3x3+faQdHexU2Hs/Ac4v3IKfAKLu0GiM1qAwYMABlZWX49ddfcfDgQbRr1w4DBgxAZmamzLJq3KlL13t3J2QY4FNPjbixEXikTYDssoiISCEGd2iIlaPDoXV2xKGUPAyevxvnsgtkl1UjpK31c+XKFfj6+uK3337DfffdBwDIz8+Hu7s7tmzZgt69e9/yOKPRCKPxzyRpMBig0+lsdq2fbYmX8Y9Vh1BYYkJTv3pYFtUFOi8X2WUREZECncsuwKhl+5GSew1aZ0csGt4J3UK8ZZd1VxS/1o+3tzdatGiBzz//HIWFhSgrK8OiRYvg5+eHTp063fa42NhYaLVay6bT6WqwautaueciRq/Yj8ISEyKbeOPbCZEMKUREdFtNfOth3cRIdAjygL6oFMOX7MW6w2myy6pWUldPTktLw6BBg3Do0CHY2dnBz88PGzduRIcOHW57TG24o2I2C8RuSsDi368vQPVUp4b4YHBbqB3YaZaIiO6suNSEV1cfxcbjGQCAyb2b4eWHmtlUv0Zpd1RiYmKgUqkq3BITEyGEQHR0NPz8/PD7779j3759GDRoEB577DFkZGTc9v01Gg3c3d3LbbakqMSECV8etISUf/ZpjtlPhTGkEBFRpTk52mPu0A4Y37MJAODjrWfx6uqjKCkzS67M+qx+RyU7Oxs5OTkVtgkJCcHvv/+OPn364OrVq+XCRrNmzTB69GjExMRU6vtVNpEpQXa+ES9+fgBHU/OgtrfD7CFhGNi+geyyiIjIhn21LwVvrj8Bk1mgW4gXFj3fGVoXR9ll3VFlP7+tPouYr68vfH1979ju2rVrAAA7u/J3Euzs7GA2175EeDYrH1HL9iM9rwgeLo74dHhndA32kl0WERHZuKFdg1DfwxnRXx7CnvO5GLxgF5ZHdUWQd+3o8yjteUNERAQ8PT0xcuRIHD16FGfOnMFrr72G5ORk9O/fX1ZZ1WJ30hU8sWA30vOK0PiPqZEZUoiIyFp6Nr+x1IoTzmcXYvD8XTiUclV2WVYhLaj4+Pjgp59+QkFBAR588EF07twZO3fuxIYNG9CuXTtZZVndNwdSMWLpPuQXl6FzI0+sndgdIb71ZJdFRES1zJ+L17ojp7AEQz/dgx+P377Pp62QOurHGpTaR0UIgY+2nMHcX5MAAAPCAjFnSDs4OdpLroyIiGqzQmMZXo47jK0JlwEAU/uFYuz9ypvpXPHzqNRmxjITJn99xBJSoh9ogv97tgNDChERVTtXjQMWDe+MqMjGAIDYTYn49/oTKDPZZv9PLslrZVcLSzBu5UHsu5ALBzsVPhjcFk93sd1J6YiIyPbY26nw9uOt0cjbBe/8cAqr9qYg/WoRPnmuA9yclD8i6K94R8WKLlwpxBMLdmPfhVy4aRywfFRXhhQiIpJmVPdgfDq8M5wd7bHjTDaGLIzHpbwi2WVVCYOKlRy4kIvB83ch+UohGng449uJkejRzEd2WUREVMc93Mofq8dFwNdNg8TMfAyevwsn0vWyy6o0BhUr+P7oJTz32V5cvVaKsIZarIuORHN/N9llERERAQDaNtRi3cRINPevhyyDEU8viscvCVmyy6oUBpV7IITAvG1JeOmrwygpM+PhVv6IG9sNfm5OsksjIiIqp6GnC9ZMiMR9zXxwrcSEMZ8fwIrdF2SXdUcMKnep1GRGzLfHMXvzaQDA6B7BWPh8J7io2T+ZiIiUyd3JEUujuuDZLjqYBTD9u5N45/tTMJmVO1MJP1Xvgr6oFBO/PIhdSTmwUwHTH2uNkX8MAyMiIlIyR3s7xD7RFkHeLvjPT6e
xdFcyUq9ew/+eba/I/9mWekfl0KFDePjhh+Hh4QFvb2+MHTsWBQUFMku6o7Sr1zBk4W7sSsqBi9oei0d0ZkghIiKbolKpMLFXU8wd2gFqBztsOZWFZz/dg8v5xbJLu4m0oHLp0iX07t0bTZs2xd69e/HTTz/h5MmTiIqKklXSHR1Ly8Pg+btxJqsAfm4arB4XgYda+ssui4iI6K481q4+Vr0YDk8XRxxL02PwvN04k5Uvu6xypE2h/+mnn2LatGnIyMiwrKB8/PhxhIWF4ezZs2jatGml3qemptD/+WQmJsUdRnGpGaEBblga1QX1PZyr7fsRERHVlAtXCvHC8v04f6UQbhoHLHi+U7VPsaH4KfSNRiPUarUlpACAs/P1D/6dO3dWeJzBYCi3VSchBJbsTMa4Lw6iuNSM+/9YoZIhhYiIaovGPq74dkIkujb2Qr6xDFHL9mH1/lTZZQGQGFQefPBBZGZmYvbs2SgpKcHVq1cRExMDAMjIuP1qj7GxsdBqtZZNp6u+mV9NZoG3vzuJd384BSGAoV2DsGRkZ5ubfpiIiOhOPF3VWPliVwxsXx9lZoHXvz2G2ZsTYZY8IsjqQSUmJgYqlarCLTExEa1bt8aKFSvw4YcfwsXFBQEBAQgODoa/v3+5uyx/N3XqVOj1esuWmlo9ia/QWIaxnx/AiviL179vv1B8MLgNHO05opuIiGonjYM9Pn6mPSY9eL37xbxt5/Dy10dQXGqSVpPV+6hkZ2cjJyenwjYhISFQq9WW11lZWXB1dYVKpYK7uzvi4uIwZMiQSn2/6uqj8uKK/diacBkaBzv895n2eLRtoNXem4iISOm+OZCKqWuPo8ws8ESHBvjomfZWff/Kfn5bfcC0r68vfH19q3SMv//1kTNLly6Fk5MTHn74YWuXVWVTHm6BM1kF+PjZ9ugY5Cm7HCIioho1pLMODTycEbP2OP7xYOUGuFQHaaN+AOCTTz5BZGQk6tWrhy1btuC1117DzJkzMWnSpEq/R3WO+ik1mfmoh4iI6rTq+iyUdkelKvbt24fp06ejoKAAoaGhWLRoEYYPHy6zpHIYUoiIqK6T/VkoNah8/vnnMr89ERERKRxvGRAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWJVW1B5//33ERkZCRcXF3h4eNyyTUpKCvr37w8XFxf4+fnhtddeQ1lZWXWVRERERDam2qbQLykpwZAhQxAREYElS5bc9HWTyYT+/fsjICAAu3fvRkZGBkaMGAFHR0d88MEH1VUWERER2ZBqXz15+fLlmDx5MvLy8srt37RpEwYMGIBLly7B398fALBw4UK88cYbyM7OhlqtvuX7GY1GGI1Gy2u9Xo+goCCkpqZaffVkIiIiqh4GgwE6nQ55eXnQarW3bSdtUcL4+Hi0bdvWElIAoG/fvpgwYQJOnjyJDh063PK42NhYzJgx46b9Op2u2molIiKi6pGfn6/MoJKZmVkupACwvM7MzLztcVOnTsWUKVMsr81mM3Jzc+Ht7Q2VSmXVGm+kvdp6t4bnZ/tq+zny/GxfbT9Hnt/dE0IgPz8f9evXr7BdlYJKTEwMZs2aVWGbhIQEhIaGVuVtq0Sj0UCj0ZTbd7vOutbi7u5eK/8B3sDzs321/Rx5fravtp8jz+/uVHQn5YYqBZVXX30VUVFRFbYJCQmp1HsFBARg37595fZlZWVZvkZERERUpaDi6+sLX19fq3zjiIgIvP/++7h8+TL8/PwAAFu2bIG7uztatWplle9BREREtq3a+qikpKQgNzcXKSkpMJlMOHLkCACgadOmqFevHvr06YNWrVph+PDh+M9//oPMzEy8+eabiI6OvunRjiwajQbTp09XTD3WxvOzfbX9HHl+tq+2nyPPr/pV2/DkqKgorFix4qb927ZtQ69evQAAFy9exIQJE7B9+3a4urpi5MiRmDlzJhw
cpPXxJSIiIgWp9nlUiIiIiO4W1/ohIiIixWJQISIiIsViUCEiIiLFYlAhIiIixarTQeX9999HZGQkXFxcKj27rRACb731FgIDA+Hs7IzevXvj7Nmz5drk5uZi2LBhcHd3h4eHB0aPHo2CgoJqOIOKVbWOCxcuQKVS3XL75ptvLO1u9fW4uLiaOKWb3M3PulevXjfVP378+HJtUlJS0L9/f7i4uMDPzw+vvfYaysrKqvNUbqmq55ebm4uXXnoJLVq0gLOzM4KCgjBp0iTo9fpy7WRew3nz5qFx48ZwcnJCeHj4TRM//t0333yD0NBQODk5oW3btvjxxx/Lfb0yv5M1qSrnt3jxYtx3333w9PSEp6cnevfufVP7qKiom67VI488Ut2ncVtVOb/ly5ffVLuTk1O5Nkq7fkDVzvFWf09UKhX69+9vaaOka/jbb7/hscceQ/369aFSqbB+/fo7HrN9+3Z07NgRGo0GTZs2xfLly29qU9Xf6yoRddhbb70lPvroIzFlyhSh1WordczMmTOFVqsV69evF0ePHhWPP/64CA4OFkVFRZY2jzzyiGjXrp3Ys2eP+P3330XTpk3F0KFDq+ksbq+qdZSVlYmMjIxy24wZM0S9evVEfn6+pR0AsWzZsnLt/nr+NeluftY9e/YUY8aMKVe/Xq+3fL2srEy0adNG9O7dWxw+fFj8+OOPwsfHR0ydOrW6T+cmVT2/48ePiyeeeEJ89913IikpSfzyyy+iWbNm4sknnyzXTtY1jIuLE2q1WixdulScPHlSjBkzRnh4eIisrKxbtt+1a5ewt7cX//nPf8SpU6fEm2++KRwdHcXx48ctbSrzO1lTqnp+zz33nJg3b544fPiwSEhIEFFRUUKr1Yq0tDRLm5EjR4pHHnmk3LXKzc2tqVMqp6rnt2zZMuHu7l6u9szMzHJtlHT9hKj6Oebk5JQ7vxMnTgh7e3uxbNkySxslXcMff/xR/Pvf/xZr164VAMS6desqbH/+/Hnh4uIipkyZIk6dOiXmzp0r7O3txU8//WRpU9WfWVXV6aByw7JlyyoVVMxmswgICBCzZ8+27MvLyxMajUZ89dVXQgghTp06JQCI/fv3W9ps2rRJqFQqkZ6ebvXab8dadbRv31688MIL5fZV5h93Tbjbc+zZs6d4+eWXb/v1H3/8UdjZ2ZX7g7pgwQLh7u4ujEajVWqvDGtdw9WrVwu1Wi1KS0st+2Rdw65du4ro6GjLa5PJJOrXry9iY2Nv2f7pp58W/fv3L7cvPDxcjBs3TghRud/JmlTV8/u7srIy4ebmJlasWGHZN3LkSDFw4EBrl3pXqnp+d/rbqrTrJ8S9X8P//ve/ws3NTRQUFFj2Keka/lVl/g68/vrronXr1uX2PfPMM6Jv376W1/f6M7uTOv3op6qSk5ORmZmJ3r17W/ZptVqEh4cjPj4eABAfHw8PDw907tzZ0qZ3796ws7PD3r17a6xWa9Rx8OBBHDlyBKNHj77pa9HR0fDx8UHXrl2xdOlSCAnT8dzLOX755Zfw8fFBmzZtMHXqVFy7dq3c+7Zt27bc6t59+/aFwWDAyZMnrX8it2Gtf0t6vR7u7u43TaRY09ewpKQEBw8eLPf7Y2dnh969e1t+f/4uPj6+XHvg+rW40b4yv5M15W7O7++uXbuG0tJSeHl5ldu/fft2+Pn5oUWLFpgwYQJycnKsWntl3O35FRQUoFGjRtDpdBg4cGC53yElXT/AOtdwyZIlePbZZ+Hq6lpuvxKu4d240++gNX5md8IpYKsgMzMTAMp9gN14feNrmZmZlrWLbnBwcICXl5elTU2wRh1LlixBy5YtERkZWW7/O++8gwcffBAuLi74+eefMXHiRBQUFGDSpElWq78y7vYcn3vuOTRq1Aj169fHsWPH8MYbb+D06dNYu3at5X1vdY1vfK2mWOMaXrlyBe+++y7Gjh1bbr+Ma3jlyhWYTKZb/mwTExNvecztrsVff99u7Ltdm5p
yN+f3d2+88Qbq169f7o/+I488gieeeALBwcE4d+4c/vWvf6Ffv36Ij4+Hvb29Vc+hIndzfi1atMDSpUsRFhYGvV6POXPmIDIyEidPnkTDhg0Vdf2Ae7+G+/btw4kTJ7BkyZJy+5VyDe/G7X4HDQYDioqKcPXq1Xv+d38ntS6oxMTEYNasWRW2SUhIQGhoaA1VZF2VPb97VVRUhFWrVmHatGk3fe2v+zp06IDCwkLMnj3bah9y1X2Of/3Qbtu2LQIDA/HQQw/h3LlzaNKkyV2/b2XV1DU0GAzo378/WrVqhbfffrvc16r7GlLVzZw5E3Fxcdi+fXu5DqfPPvus5b/btm2LsLAwNGnSBNu3b8dDDz0ko9RKi4iIQEREhOV1ZGQkWrZsiUWLFuHdd9+VWFn1WLJkCdq2bYuuXbuW22/L11AJal1QefXVVxEVFVVhm5CQkLt674CAAABAVlYWAgMDLfuzsrLQvn17S5vLly+XO66srAy5ubmW4+9FZc/vXutYs2YNrl27hhEjRtyxbXh4ON59910YjUarLFxVU+d4Q3h4OAAgKSkJTZo0QUBAwE091rOysgDAZq5hfn4+HnnkEbi5uWHdunVwdHSssL21r+Gt+Pj4wN7e3vKzvCErK+u25xMQEFBh+8r8TtaUuzm/G+bMmYOZM2di69atCAsLq7BtSEgIfHx8kJSUVKMfcvdyfjc4OjqiQ4cOSEpKAqCs6wfc2zkWFhYiLi4O77zzzh2/j6xreDdu9zvo7u4OZ2dn2Nvb3/O/izuySk8XG1fVzrRz5syx7NPr9bfsTHvgwAFLm82bN0vrTHu3dfTs2fOmkSK389577wlPT8+7rvVuWetnvXPnTgFAHD16VAjxZ2fav/ZYX7RokXB3dxfFxcXWO4E7uNvz0+v1olu3bqJnz56isLCwUt+rpq5h165dxT/+8Q/La5PJJBo0aFBhZ9oBAwaU2xcREXFTZ9qKfidrUlXPTwghZs2aJdzd3UV8fHylvkdqaqpQqVRiw4YN91xvVd3N+f1VWVmZaNGihXjllVeEEMq7fkLc/TkuW7ZMaDQaceXKlTt+D5nX8K9Qyc60bdq0Kbdv6NChN3WmvZd/F3es0yrvYqMuXrwoDh8+bBmCe/jwYXH48OFyQ3FbtGgh1q5da3k9c+ZM4eHhITZs2CCOHTsmBg4ceMvhyR06dBB79+4VO3fuFM2aNZM2PLmiOtLS0kSLFi3E3r17yx139uxZoVKpxKZNm256z++++04sXrxYHD9+XJw9e1bMnz9fuLi4iLfeeqvaz+dWqnqOSUlJ4p133hEHDhwQycnJYsOGDSIkJETcf//9lmNuDE/u06ePOHLkiPjpp5+Er6+vtOHJVTk/vV4vwsPDRdu2bUVSUlK54ZBlZWVCCLnXMC4uTmg0GrF8+XJx6tQpMXbsWOHh4WEZYTV8+HARExNjab9r1y7h4OAg5syZIxISEsT06dNvOTz5Tr+TNaWq5zdz5kyhVqvFmjVryl2rG3+D8vPzxT//+U8RHx8vkpOTxdatW0XHjh1Fs2bNajQ03+35zZgxQ2zevFmcO3dOHDx4UDz77LPCyclJnDx50tJGSddPiKqf4w09evQQzzzzzE37lXYN8/PzLZ91AMRHH30kDh8+LC5evCiEECImJkYMHz7c0v7G8OTXXntNJCQkiHnz5t1yeHJFP7N7VaeDysiRIwWAm7Zt27ZZ2uCP+SZuMJvNYtq0acLf319oNBrx0EMPidOnT5d735ycHDF06FBRr1494e7uLkaNGlUu/NSUO9WRnJx80/kKIcTUqVOFTqcTJpPppvfctGmTaN++vahXr55wdXUV7dq1EwsXLrxl25pQ1XNMSUkR999/v/Dy8hIajUY0bdpUvPbaa+XmURFCiAsXLoh+/foJZ2dn4ePjI1599dVyw3trSlXPb9u2bbf8Nw1AJCcnCyHkX8O5c+eKoKAgoVarRdeuXcWePXssX+vZs6cYOXJkufarV68
WzZs3F2q1WrRu3Vps3Lix3Ncr8ztZk6pyfo0aNbrltZo+fboQQohr166JPn36CF9fX+Ho6CgaNWokxowZY7UPgLtRlfObPHmypa2/v7949NFHxaFDh8q9n9KunxBV/zeamJgoAIiff/75pvdS2jW83d+IG+c0cuRI0bNnz5uOad++vVCr1SIkJKTcZ+INFf3M7pVKCAnjSomIiIgqgfOoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFiMagQERGRYjGoEBERkWIxqBAREZFi/T8qIITBChX6pQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# log(1/(1 + exp(-x - eps) - 1/(1 + exp(-x + eps)))\n", "# plot \n", "inv_scale = 10\n", "f = lambda x : np.log(1/(1 + np.exp(- inv_scale * (x + 0.1))) - 1/(1 + np.exp(- inv_scale* (x - 0.1))))\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "x = np.linspace(-1, 1, 100)\n", "y = f(x)\n", "plt.plot(x, y)\n", "plt.yticks(np.arange(-10, 10, 1))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100, 0.9100,\n", " 0.9100])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_classes = 10\n", "beta = 0.1\n", "mat = torch.ones(num_classes, num_classes) * beta / num_classes\n", "# diagonal is 1 - (K - 1)/K * beta\n", "mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100, 0.0100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.9100,\n", " 0.0100],\n", " [0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 
0.0100,\n", " 0.9100]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mat" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 1, 32, 32, 16])\n" ] } ], "source": [ "from typing import Dict, Tuple\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "from torchvision.utils import save_image, make_grid\n", "\n", "\n", "blk = lambda ic, oc: nn.Sequential(\n", " nn.Conv2d(ic, oc, 7, padding=3),\n", " nn.GroupNorm(oc // 8, oc, affine = False, eps = 1e-4),\n", " nn.LeakyReLU(),\n", ")\n", "\n", "\n", "class DummyX0Model(nn.Module):\n", " \"\"\"\n", " This should be unet-like, but let's don't think about the model too much :P\n", " Basically, any universal R^n -> R^n model should work.\n", " \"\"\"\n", "\n", " def __init__(self, n_channel: int, N) -> None:\n", " super(DummyX0Model, self).__init__()\n", " self.start = blk(n_channel, 16)\n", " self.pe = nn.Parameter(torch.randn(1, 16, 32, 32))\n", " self.conv = nn.Sequential(\n", " blk(16, 128),\n", " blk(128, 256),\n", " blk(256, 512),\n", " blk(512, 256),\n", " blk(256, 128),\n", " blk(128, 64),\n", " nn.Conv2d(64, 2 * n_channel, 3, padding=1),\n", " )\n", " self.N = N\n", "\n", " def forward(self, x, t) -> torch.Tensor:\n", " # Lets think about using t later. 
In the paper, they used Tr-like positional embeddings.\n", " st = x.float() / self.N - 0.5\n", " x = self.start(st) + self.pe\n", " y = self.conv(x)\n", " loc, log_scale = y.chunk(2, dim=1)\n", "\n", " return torch.tanh(loc + st), log_scale\n", "\n", "blk = lambda ic, oc: nn.Sequential(\n", " nn.Conv2d(ic, oc, 5, padding=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", " nn.Conv2d(oc, oc, 5, padding=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", " nn.Conv2d(oc, oc, 5, padding=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", ")\n", "\n", "blku = lambda ic, oc: nn.Sequential(\n", " nn.Conv2d(ic, oc, 5, padding=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", " nn.Conv2d(oc, oc, 5, padding=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", " nn.Conv2d(oc, oc, 5, padding=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", " nn.ConvTranspose2d(oc, oc, 2, stride=2),\n", " nn.GroupNorm(oc // 8, oc),\n", " nn.LeakyReLU(),\n", " \n", ")\n", "\n", "class DummyX0Model(nn.Module):\n", " \n", "\n", " def __init__(self, n_channel: int, N : int = 16) -> None:\n", " super(DummyX0Model, self).__init__()\n", " self.down1 = blk(n_channel, 16)\n", " self.down2 = blk(16, 32)\n", " self.down3 = blk(32, 64)\n", " self.down4 = blk(64, 512)\n", " self.down5 = blk(512, 512)\n", " self.up1 = blku(512, 512)\n", " self.up2 = blku(512 + 512, 64) # Corrected to account for concatenated feature maps\n", " self.up3 = blku(64 + 64, 32) # Corrected to account for concatenated feature maps\n", " self.up4 = blku(32 + 32, 16) # Corrected to account for concatenated feature maps\n", " self.convlast = blk(32, 16)\n", " self.final = nn.Conv2d(16, N * n_channel, 1, bias = False)\n", "\n", " # initialize final with zero\n", " #self.final.weight.data.zero_()\n", " #self.final.bias.data.zero_()\n", "\n", " self.tr1 = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n", " self.tr2 = nn.TransformerEncoderLayer(d_model=512, 
nhead=8)\n", " self.tr3 = nn.TransformerEncoderLayer(d_model=64, nhead=8)\n", "\n", " self.temb_1 = nn.Linear(32, 16)\n", " self.temb_2 = nn.Linear(32, 32)\n", " self.temb_3 = nn.Linear(32, 64)\n", " self.temb_4 = nn.Linear(32, 512)\n", " self.N = N\n", "\n", " def forward(self, x, t) -> torch.Tensor:\n", " x = (2 * x.float() / self.N) - 1.0\n", " t = t.float().reshape(-1, 1) / 500\n", " t_as_sin = [torch.sin(t * 3.1415 * 2 ** i) for i in range(16)]\n", " t_as_cos = [torch.cos(t * 3.1415 * 2 ** i) for i in range(16)]\n", " # concat and send it to t_emb\n", " t_emb_1 = self.temb_1(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n", " t_emb_2 = self.temb_2(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n", " t_emb_3 = self.temb_3(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n", " t_emb_4 = self.temb_4(torch.cat(t_as_sin + t_as_cos, dim=1).to(x.device)).reshape(x.shape[0], -1, 1, 1)\n", " \n", " x1 = self.down1(x) + t_emb_1\n", " x2 = self.down2(nn.functional.avg_pool2d(x1, 2)) + t_emb_2\n", " x3 = self.down3(nn.functional.avg_pool2d(x2, 2)) + t_emb_3\n", " x4 = self.down4(nn.functional.avg_pool2d(x3, 2)) + t_emb_4\n", " x5 = self.down5(nn.functional.avg_pool2d(x4, 2))\n", "\n", " x5 = self.tr1(x5.reshape(x5.shape[0], x5.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(x5.shape)\n", "\n", " y = self.up1(x5)\n", "\n", " y = self.tr2(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(y.shape)\n", " \n", " y = self.up2(torch.cat([x4, y], dim=1))\n", "\n", " y = self.tr3(y.reshape(y.shape[0], y.shape[1], -1).transpose(1, 2)).transpose(1, 2).reshape(y.shape)\n", "\n", " y = self.up3(torch.cat([x3, y], dim=1))\n", "\n", " y = self.up4(torch.cat([x2, y], dim=1))\n", "\n", "\n", " y = self.convlast(torch.cat([x1, y], dim=1))\n", " y = self.final(y)\n", " # reshape to B, C, H, W, N\n", " y = y.reshape(y.shape[0], -1, self.N, 
*x.shape[2:]).transpose(2, -1).contiguous()\n", " \n", " return y\n", "\n", "\n", "# check if it takes 2, 1, 28, 28 input\n", "\n", "model = DummyX0Model(1, N = 16)\n", "print(model(torch.randn(2, 1, 32, 32), torch.tensor([1, 2])).shape)\n", "\n", "\n", "\n", "def get_logits_from_logistic_pars(loc, log_scale, num_classes = 10):\n", " loc = loc.unsqueeze(-1)\n", " log_scale = log_scale.unsqueeze(-1)\n", " inv_scale = (-log_scale + 2.0).exp()\n", "\n", " bin_width = 2.0 / (num_classes - 1)\n", " bin_centers = torch.linspace(-1.0, 1.0, num_classes).to(loc.device)\n", " bin_centers = bin_centers.reshape([1] * (len(loc.shape) - 1) + [num_classes])\n", " bin_centers = bin_centers - loc\n", " log_cdf_min = -torch.log1p((-inv_scale * (bin_centers - 0.5 * bin_width)).exp())\n", " log_cdf_plus = -torch.log1p((-inv_scale * (bin_centers + 0.5 * bin_width)).exp())\n", " logits = log_minus_exp(log_cdf_plus, log_cdf_min)\n", " return logits\n", "\n", "\n", "def log_minus_exp(a, b, epsilon=1e-6):\n", " return a + torch.log1p(-torch.exp(b - a) + epsilon)\n", "\n", "\n", "\n", "class D3PM(nn.Module):\n", " def __init__(\n", " self,\n", " x0_model: nn.Module,\n", " n_T: int,\n", " num_classes: int = 10,\n", " forward_type = 'uniform',\n", " hybrid_loss_coeff = 0.001\n", " ) -> None:\n", " super(D3PM, self).__init__()\n", " self.x0_model = x0_model\n", "\n", " self.n_T = n_T\n", " self.hybrid_loss_coeff = hybrid_loss_coeff\n", " self.beta_t = [1 / (self.n_T - t + 1) for t in range(1, self.n_T + 1)]\n", " self.eps = 1e-8\n", " self.num_classses = num_classes\n", " q_onestep_mats = []\n", " q_mats = [] # these are cumulative\n", "\n", " for beta in self.beta_t:\n", "\n", " if forward_type == 'uniform':\n", " mat = torch.ones(num_classes, num_classes) * beta / num_classes\n", " mat.diagonal().fill_(1 - (num_classes - 1) * beta / num_classes)\n", " q_onestep_mats.append(mat)\n", " else:\n", " raise NotImplementedError\n", " q_one_step_mats = torch.stack(q_onestep_mats, dim=0)\n", "\n", " 
q_one_step_transposed = q_one_step_mats.transpose(1, 2) # this will be used for q_posterior_logits\n", "\n", " q_mat_t = q_onestep_mats[0]\n", " q_mats = [q_mat_t]\n", " for idx in range(1, self.n_T):\n", " q_mat_t = q_mat_t @ q_onestep_mats[idx]\n", " q_mats.append(q_mat_t)\n", " q_mats = torch.stack(q_mats, dim=0)\n", " self.logit_type = 'logit'\n", "\n", " # register\n", " self.register_buffer(\"q_one_step_transposed\", q_one_step_transposed)\n", " self.register_buffer(\"q_mats\", q_mats)\n", "\n", " assert self.q_mats.shape == (self.n_T, num_classes, num_classes), self.q_mats.shape\n", " \n", " def _at(self, a, t, x):\n", " # t is 1-d, x is integer value of 0 to num_classes - 1\n", " bs = t.shape[0]\n", " t = t.reshape((bs, *[1] * (x.dim() - 1)))\n", " #out[i, j, k, l, m] = a[t[i, j, k, l], x[i, j, k, l], m]\n", " return a[t - 1, x, :]\n", "\n", "\n", " def q_posterior_logits(self, x_0, x_t, t):\n", " # if t == 1, this means we return the L_0 loss, so directly try to x_0 logits.\n", " # otherwise, we return the L_{t-1} loss.\n", " # Also, we never have t == 0.\n", "\n", " # if x_0 is integer, we convert it to one-hot.\n", " if x_0.dtype == torch.int64 or x_0.dtype == torch.int32:\n", " x_0_logits = torch.log(torch.nn.functional.one_hot(x_0, self.num_classses) + self.eps)\n", " else:\n", " x_0_logits = x_0.clone()\n", "\n", " assert x_0_logits.shape == x_t.shape + (self.num_classses,), print(f\"x_0_logits.shape: {x_0_logits.shape}, x_t.shape: {x_t.shape}\")\n", "\n", " # Here, we caclulate equation (3) of the paper. 
Note that the x_0 Q_t x_t^T is a normalizing constant, so we don't deal with that.\n", " # fact1 is \"guess of x_{t-1}\" from x_t\n", " # fact2 is \"guess of x_{t-1}\" from x_0\n", "\n", " fact1 = self._at(self.q_one_step_transposed, t, x_t)\n", " #fact2 = self._at_onehot(self.q_mats, t-1, )\n", " # x, a[t-1]\n", "\n", "\n", " softmaxed = torch.softmax(x_0_logits, dim=-1) # bs, ..., num_classes\n", " qmats2 = self.q_mats[t-2] # bs, num_classes, num_classes\n", "\n", " fact2 = torch.einsum('b...c,bcd->b...d', softmaxed, qmats2)\n", " \n", " #print(f\"Fact1Fact2\", fact1.shape, fact2.shape)\n", " out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)\n", " #print(f\"out: {out.shape}\")\n", " t_broadcast = t.reshape((t.shape[0], *[1] * (x_t.dim())))\n", " #print(t_broadcast.shape)\n", " bc = torch.where(t_broadcast == 1, x_0_logits, out)\n", " #print(f\"bc: {bc.shape}\")\n", " return bc\n", "\n", " def vb(self, dist1, dist2):\n", " out = (torch.softmax(dist1 + self.eps, dim = -1)*(torch.log_softmax(dist1 + self.eps, dim = -1) - torch.log_softmax(dist2 + self.eps, dim = -1)))\n", " return out.sum(dim=-1).mean()\n", "\n", " \n", " def q_sample(self, x_0, t, noise):\n", " # forward process, x_0 is the clean input.\n", " \n", " logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)\n", " noise = torch.clip(noise, self.eps, 1.0)\n", " gumbel_noise = -torch.log(-torch.log(noise))\n", " return torch.argmax(logits + gumbel_noise, dim=-1)\n", "\n", " def model_predict(self, x_0, t):\n", " if self.logit_type == 'logit':\n", " predicted_x0_logits = self.x0_model(x_0, t)\n", " else:\n", " loc, log_scale = self.x0_model(x_0, t)\n", " predicted_x0_logits = get_logits_from_logistic_pars(loc, log_scale, self.num_classses)\n", " return predicted_x0_logits\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Makes forward diffusion x_t from x_0, and tries to guess x_0 value from x_t using x0_model.\n", " x is one-hot of dim (bs, ...), with 
int values of 0 to num_classes - 1\n", " \"\"\"\n", " t = torch.randint(1, self.n_T, (x.shape[0],), device=x.device)\n", " x_t = self.q_sample(x, t, torch.rand((*x.shape, self.num_classses), device=x.device))\n", " # x_t is same shape as x\n", " assert x_t.shape == x.shape, print(f\"x_t.shape: {x_t.shape}, x.shape: {x.shape}\")\n", " # we use hybrid loss.\n", " \n", " predicted_x0_logits = self.model_predict(x, t)\n", "\n", "\n", " # based on this, we first do vb loss.\n", " true_q_posterior_logits = self.q_posterior_logits(x, x_t, t)\n", " #print(f\"predicted_x0_logits: {predicted_x0_logits.shape}, true_q_posterior_logits: {true_q_posterior_logits.shape}\")\n", " pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)\n", "\n", " vb_loss = self.vb(true_q_posterior_logits,pred_q_posterior_logits)\n", "\n", "\n", "\n", " predicted_x0_logits = predicted_x0_logits.flatten(start_dim = 0, end_dim = -2)\n", " x = x.flatten(start_dim = 0, end_dim = -1)\n", " #print(f\"predicted_x0_logits: {predicted_x0_logits.shape}, x: {x.shape}\")\n", "\n", "\n", " ce_loss = torch.nn.CrossEntropyLoss()(predicted_x0_logits, x)\n", "\n", " return vb_loss + ce_loss*self.hybrid_loss_coeff, {\"vb_loss\": vb_loss.detach().item(), \"ce_loss\": ce_loss.detach().item()}\n", "\n", " def p_sample(self, x, t, noise):\n", " \n", " predicted_x0_logits = self.model_predict(x, t)\n", " pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x, t)\n", "\n", " noise = torch.clip(noise, self.eps, 1.0)\n", "\n", " not_first_step = (t != 1).float().reshape((x.shape[0], *[1] * (x.dim())))\n", " gumbel_noise = -torch.log(-torch.log(noise))\n", " sample = torch.argmax(pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1)\n", " return sample\n", "\n", " def sample(self, x = None):\n", " for t in reversed(range(1, self.n_T)):\n", " t = torch.tensor([t]*x.shape[0], device=x.device)\n", " x = self.p_sample(x, t, torch.rand((*x.shape, self.num_classses), 
device=x.device))\n", " \n", " return x\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#d3pm = D3PM(DummyX0Model(1, 20), 1000, num_classes = 20).cuda()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#d3pm.sample(torch.randint(0, 2, (2, 1, 32, 32)).cuda())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# d3pm.train()\n", "# d3pm.forward(torch.randint(0, 2, (128, 1, 28, 28)).cuda())\n", "# dataset = MNIST(\n", "# \"./data\",\n", "# train=True,\n", "# download=True,\n", "# transform=transforms.ToTensor()\n", "# )\n", "# dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=20)\n", "# d3pm(x = next(iter(dataloader))[0].cuda().long())\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total Param Count: 63336704\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/root/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 30, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", " warnings.warn(_create_warning_msg(\n" ] } ], "source": [ "N = 16\n", "d3pm = D3PM(DummyX0Model(1, N), 1000, num_classes = N, hybrid_loss_coeff=0.0).cuda()\n", "print(f\"Total Param Count: {sum([p.numel() for p in d3pm.x0_model.parameters()])}\")\n", "dataset = MNIST(\n", " \"./data\",\n", " train=True,\n", " download=True,\n", " transform= transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Pad(2),\n", " ])\n", " )\n", "dataloader = DataLoader([dataset[0]] * 50000, batch_size=32, shuffle=True, num_workers=32)\n", "#dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=32)\n", "\n", "optim = torch.optim.AdamW(d3pm.x0_model.parameters(), lr=4e-4, betas=(0.95, 0.99))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/1563 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "loss: 0.0026, norm: 0.0014, param_norm: 935.4885, vb_loss: 0.0008, ce_loss: 0.7653: 19%|█▉ | 298/1563 [00:23<00:50, 25.17it/s]" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVCUlEQVR4nO3df2xV9f3H8VdL29tqube0rPfSwZVuklUHONZCuWL2y7sAMwqj24SwUZXMwIqjNpmIDvbHxkq2ZKKLYrZlsGUyHIngZFPCCsJYaks76qyVirGRRrgXHfbegtJi72d/fOP9cvlRekv7ube9z0fySeg5h3vffX9o74vPOefeNGOMEQAAgCXpiS4AAACkFsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsGrYwseTTz6pyZMnKzs7W+Xl5WpsbByupwIAACNI2nB8tsuzzz6rZcuW6emnn1Z5ebk2bdqkHTt2qL29XYWFhf3+3UgkohMnTmjs2LFKS0sb6tIAAMAwMMaou7tbRUVFSk+/ytqGGQazZs0yVVVV0a/7+vpMUVGRqa2tverf7ezsNJIYDAaDwWCMwNHZ2XnV1/ohP+3S29ur5uZm+f3+6Lb09HT5/X7V19dfcnxPT4/C4XB0GD5kFwCAEWvs2LFXPWbIw8f777+vvr4+ud3umO1ut1uBQOCS42tra+VyuaLD6/UOdUkAAMCSgVwykfC7XdauXatQKBQdnZ2diS4JAAAMo4yhfsDx48drzJgxCgaDMduDwaA8Hs8lxzscDjkcjqEuAwAAJKkhX/nIyspSaWmp6urqotsikYjq6urk8/mG+ukAAMAIM+QrH5JUU1OjyspKlZWVadasWdq0aZPOnj2re++9dzieDgAAjCDDEj7uvvtuvffee1q/fr0CgYC+8IUv6KWXXrrkIlQAAJB6huVNxq5FOByWy+VKdBkAAGAQQqGQnE5nv8ck/G4XAACQWggfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwKqMRBeA+H3wwQfD/hzjxo0b9udAfIZj3pnnxLPx83wx5t0+5jkWKx8AAMAqwgcAALAq7vBx8OBB3XnnnSoqKlJaWpp27doVs98Yo/Xr12vChAnKycmR3+/XsWPHhqpeAAAwwsV9zcfZs2d1yy236L777tOiRYsu2f+LX/xCTzzxhP7whz+ouLhY69at09y5c9XW1qbs7OwhKXo0SsT5QAyPkTSXF9eazOeIk81Immfm9dow10Mv7vAxf/58zZ8//7L7jDHatGmTfvzjH2vBggWSpD/+8Y9yu93atWuXFi9efG3VAgCAEW9Ir/no6OhQIBCQ3++PbnO5XCovL1d9ff1l/05PT4/C4XDMAAAAo9eQho9AICBJcrvdMdvdbnd038Vqa2vlcrmiY9KkSUNZEgAASDIJf5+PtWvXqqamJvp1OBwe0QEk2c4NjpTzfyNNoueZebWDeU4NiZ5nKfXmekhXPjwejyQpGAzGbA8Gg9F9F3M4HHI6nTEDAACMXkMaPoqLi+XxeFRXVxfdFg6H1dDQIJ/PN5RPBQAARqi4T7ucOXNGb731VvTrjo4OtbS0KD8/X16vV9XV1frZz36mKVOmRG+1LSoq0sKFC4eyblxBqi3d2ZIMy7IAMFrEHT6ampr01a9+Nfr1J9drVFZWauvWrXrooYd09uxZ3X///erq6tJtt92ml156iff4AAAAkqQ0Y4xJdBEXCofDcrlciS5j0BL9P2RWPoZHouf1YsyzHYmed+bZjkTPszS65joUCl31+k0+2wUAAFjFygcAABgyrHwAAICkQ/gAAABWET4AAIBVhA8AAGAV4QMAAFh
F+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFiVkegCAGC0++CDD664b9y4cRYrAZIDKx8AAMAqwgcAALCK8AEAAKzimg8AGGZc1wHEYuUDAABYFVf4qK2t1cyZMzV27FgVFhZq4cKFam9vjznm3LlzqqqqUkFBgXJzc1VRUaFgMDikRQMAgJErzRhjBnrwvHnztHjxYs2cOVMff/yxHnnkEbW2tqqtrU3XX3+9JGnlypX629/+pq1bt8rlcmnVqlVKT0/Xv/71rwE9RzgclsvlGtx3k4L6u4XvYiz9jl4X/jtgnkc25hIjXSgUktPp7PeYuMLHxd577z0VFhbqwIED+tKXvqRQKKRPfepT2rZtm771rW9Jko4ePaqbbrpJ9fX1mj179lUfk/ARH8IHJF6wRhPmEiPdQMLHNV3zEQqFJEn5+fmSpObmZp0/f15+vz96TElJibxer+rr6y/7GD09PQqHwzEDAACMXoMOH5FIRNXV1ZozZ46mTp0qSQoEAsrKylJeXl7MsW63W4FA4LKPU1tbK5fLFR2TJk0abEkAAGAEGPSttlVVVWptbdWhQ4euqYC1a9eqpqYm+nU4HCaAAJcRzyk2jA4XzzmnYUaPVH/L/UGFj1WrVmn37t06ePCgJk6cGN3u8XjU29urrq6umNWPYDAoj8dz2cdyOBxyOByDKQMAAIxAcZ12McZo1apV2rlzp/bt26fi4uKY/aWlpcrMzFRdXV10W3t7u44fPy6fzzc0FQMAgBEtrpWPqqoqbdu2Tc8//7zGjh0bvY7D5XIpJydHLpdLy5cvV01NjfLz8+V0OvXAAw/I5/MN6E4XAP+PO5mA0YPTprHiCh+bN2+WJH3lK1+J2b5lyxbdc889kqTHHntM6enpqqioUE9Pj+bOnaunnnpqSIoFAAAj3zW9z8dw4H0+4sP/jkcv5jY1pfqFiKNVKv08D/v7fAAAAMSLT7VNUkN1fnCkJ+jRhvO+qYl5H72Y28Fh5QMAAFhF+AAAAFYRPgAAgFVc85Egw3WekGs8ks9wzDXznHyY59HLxnUdqTbXrHwAAACrCB8AAMAqTrsMI06tpA6W3FMD85wamOfhx8oHAACwivABAACsInwAAACruObjGnFdR2pgnlMH5/tTA/OcWKx8AAAAqwgfAADAKk67DMJQLNexPJd8OLWSGpjn1MGpleTFygcAALCK8AEAAKwifAAAAKu45mMYcW4w+Q3VOWHmOrkxz6mBeR45WPkAAABWET4AAIBVhA8AAGBVmjHGJLqIC4XDYblcrkSXAQAABiEUCsnpdPZ7DCsfAADAqrjCx+bNmzV9+nQ5nU45nU75fD69+OKL0f3nzp1TVVWVCgoKlJubq4qKCgWDwSEvGgAAjFxxhY+JEydq48aNam5uVlNTk772ta9pwYIFev311yVJDz74oF544QXt2LFDBw4c0IkTJ7Ro0aJhKRwAAIxQ5hqNGzfO/O53vzNdXV0mMzPT7NixI7rvjTfeMJJMfX39gB8vFAoZSQwGg8FgMEbgCIVCV32tH/Q1H319fdq+fbvOnj0rn8+n5uZmnT9/Xn6/P3pMSUmJvF6v6uvrr/g4PT09CofDMQMAAIxecYeP1157Tbm5uXI4HFqxYoV27typm2++WYFAQFlZWcrLy4s53u12KxAIXPHxamtr5XK5omPSpElxfxMAAGDkiDt8fO5zn1NLS4saGhq0cuVKVVZWqq2tbdAFrF27VqFQKDo6OzsH/VgAACD5xf3ZLllZWbrxxhslSaWlpTp8+LAef/xx3X333ert7VVXV1fM6kcwGJTH47ni4zkcDjkcjvgrBwAAI9I1v89HJBJRT0+PSktLlZmZqbq6uui+9vZ2HT9+XD6f71qfBgAAjBJxrXysXbtW8+fPl9frVXd3t7Zt26aXX35Ze/bskcvl0vLly1VTU6P8/Hw5nU4
98MAD8vl8mj179nDVDwAARpi4wsepU6e0bNkynTx5Ui6XS9OnT9eePXv09a9/XZL02GOPKT09XRUVFerp6dHcuXP11FNPDUvhAABgZOKzXQAAwJDhs10AAEDSIXwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAqqQLH0n2UTMAACAOA3kdT7rw0d3dnegSAADAIA3kdTzpPtU2EonoxIkTMsbI6/Wqs7Pzqp+Ol4rC4bAmTZpEf66A/vSP/vSP/vSP/lxZKvfGGKPu7m4VFRUpPb3/tY0MSzUNWHp6uiZOnKhwOCxJcjqdKTeB8aA//aM//aM//aM//aM/V5aqvXG5XAM6LulOuwAAgNGN8AEAAKxK2vDhcDj0k5/8RA6HI9GlJCX60z/60z/60z/60z/6c2X0ZmCS7oJTAAAwuiXtygcAABidCB8AAMAqwgcAALCK8AEAAKwifAAAAKuSNnw8+eSTmjx5srKzs1VeXq7GxsZEl2RdbW2tZs6cqbFjx6qwsFALFy5Ue3t7zDHnzp1TVVWVCgoKlJubq4qKCgWDwQRVnFgbN25UWlqaqquro9tSvT/vvvuuvvvd76qgoEA5OTmaNm2ampqaovuNMVq/fr0mTJignJwc+f1+HTt2LIEV29PX16d169apuLhYOTk5+uxnP6uf/vSnMR+KlUr9OXjwoO68804VFRUpLS1Nu3btitk/kF6cPn1aS5culdPpVF5enpYvX64zZ85Y/C6GT3/9OX/+vNasWaNp06bp+uuvV1FRkZYtW6YTJ07EPMZo7k/cTBLavn27ycrKMr///e/N66+/br7//e+bvLw8EwwGE12aVXPnzjVbtmwxra2tpqWlxXzjG98wXq/XnDlzJnrMihUrzKRJk0xdXZ1pamoys2fPNrfeemsCq06MxsZGM3nyZDN9+nSzevXq6PZU7s/p06fNDTfcYO655x7T0NBg3n77bbNnzx7z1ltvRY/ZuHGjcblcZteuXebVV181d911lykuLjYfffRRAiu3Y8OGDaagoMDs3r3bdHR0mB07dpjc3Fzz+OOPR49Jpf78/e9/N48++qh57rnnjCSzc+fOmP0D6cW8efPMLbfcYl555RXzz3/+09x4441myZIllr+T4dFff7q6uozf7zfPPvusOXr0qKmvrzezZs0ypaWlMY8xmvsTr6QMH7NmzTJVVVXRr/v6+kxRUZGpra1NYFWJd+rUKSPJHDhwwBjzf//gMzMzzY4dO6LHvPHGG0aSqa+vT1SZ1nV3d5spU6aYvXv3mi9/+cvR8JHq/VmzZo257bbbrrg/EokYj8djfvnLX0a3dXV1GYfDYf785z/bKDGh7rjjDnPffffFbFu0aJFZunSpMSa1+3Pxi+tAetHW1mYkmcOHD0ePefHFF01aWpp59913rdVuw+XC2cUaGxuNJPPOO+8YY1KrPwORdKddent71dzcLL/fH92Wnp4uv9+v+vr6BFaWeKFQSJKUn58vSWpubtb58+djelVSUiKv15tSvaqqqtIdd9wR0weJ/vz1r39VWVmZvv3tb6uwsFAzZszQb3/72+j+jo4OBQKBmP64XC6Vl5enRH9uvfVW1dXV6c0335Qkvfrqqzp06JDmz58vif5caCC9qK+vV15ensrKyqLH+P1+paenq6GhwXrNiRYKhZSWlqa8vDxJ9OdiSfeptu+//776+vrkdrtjtrvdbh09ejRBVSVeJBJRdXW15syZo6lTp0qSAoGAsrKyov+4P+F2uxUIBBJQpX3bt2/Xv//9bx0+fPiSfanen7ffflubN29WTU2NHnnkER0+fFg//OEPlZWVpcrKymgPLvezlgr9efjhhxUOh1VSUqIxY8aor69PGzZs0NKlSyUp5ftzoYH0IhAIqLCwMGZ/RkaG8vPzU65f586d05o1a7RkyZLoJ9vSn1hJFz5weVVVVWptbdWhQ4cSXUrS6Oz
s1OrVq7V3715lZ2cnupykE4lEVFZWpp///OeSpBkzZqi1tVVPP/20KisrE1xd4v3lL3/RM888o23btunzn/+8WlpaVF1draKiIvqDQTt//ry+853vyBijzZs3J7qcpJV0p13Gjx+vMWPGXHJHQjAYlMfjSVBVibVq1Srt3r1b+/fv18SJE6PbPR6Pent71dXVFXN8qvSqublZp06d0he/+EVlZGQoIyNDBw4c0BNPPKGMjAy53e6U7s+ECRN08803x2y76aabdPz4cUmK9iBVf9Z+9KMf6eGHH9bixYs1bdo0fe9739ODDz6o2tpaSfTnQgPphcfj0alTp2L2f/zxxzp9+nTK9OuT4PHOO+9o79690VUPif5cLOnCR1ZWlkpLS1VXVxfdFolEVFdXJ5/Pl8DK7DPGaNWqVdq5c6f27dun4uLimP2lpaXKzMyM6VV7e7uOHz+eEr26/fbb9dprr6mlpSU6ysrKtHTp0uifU7k/c+bMueTW7DfffFM33HCDJKm4uFgejyemP+FwWA0NDSnRnw8//FDp6bG/AseMGaNIJCKJ/lxoIL3w+Xzq6upSc3Nz9Jh9+/YpEomovLzces22fRI8jh07pn/84x8qKCiI2Z/q/blEoq94vZzt27cbh8Nhtm7datra2sz9999v8vLyTCAQSHRpVq1cudK4XC7z8ssvm5MnT0bHhx9+GD1mxYoVxuv1mn379pmmpibj8/mMz+dLYNWJdeHdLsakdn8aGxtNRkaG2bBhgzl27Jh55plnzHXXXWf+9Kc/RY/ZuHGjycvLM88//7z5z3/+YxYsWDBqbyW9WGVlpfn0pz8dvdX2ueeeM+PHjzcPPfRQ9JhU6k93d7c5cuSIOXLkiJFkfvWrX5kjR45E79YYSC/mzZtnZsyYYRoaGsyhQ4fMlClTRs2tpP31p7e319x1111m4sSJpqWlJeb3dU9PT/QxRnN/4pWU4cMYY379618br9drsrKyzKxZs8wrr7yS6JKsk3TZsWXLlugxH330kfnBD35gxo0bZ6677jrzzW9+05w8eTJxRSfYxeEj1fvzwgsvmKlTpxqHw2FKSkrMb37zm5j9kUjErFu3zrjdbuNwOMztt99u2tvbE1StXeFw2Kxevdp4vV6TnZ1tPvOZz5hHH3005sUilfqzf//+y/6+qaysNMYMrBf//e9/zZIlS0xubq5xOp3m3nvvNd3d3Qn4boZef/3p6Oi44u/r/fv3Rx9jNPcnXmnGXPB2fgAAAMMs6a75AAAAoxvhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFb9D9oAs+8UV1pFAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "loss: 0.0011, norm: 0.0024, param_norm: 936.5659, vb_loss: 0.0005, ce_loss: 0.4201: 38%|███▊ | 598/1563 [00:44<00:38, 24.82it/s]" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWY0lEQVR4nO3df2yV5f3/8VdL6Slazikt6Tl0cKSbZOiAjRUoR8x0ehZgRmGQTQgb1ZEZWHHUJhPRwRI3VpYlE7cgZr9gy2Q4EsHJBoQVhGFKgY46EK0YmTTCOcyxnlNQWuy5Pn98v97pKVDOKe19nx/PR3Il7X1fPefd99X79J3run/kGGOMAAAAbJLrdAAAACC7UHwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbUXwAAABbDVjxsW7dOo0ePVoFBQWqrKzUoUOHBuqtAABAGskZiGe7vPjii1q4cKGef/55VVZWau3atdqyZYtaWlpUWlra68/GYjGdOXNGQ4cOVU5OTn+HBgAABoAxRu3t7SorK1Nu7nXmNswAmDJliqmurra+7+rqMmVlZaauru66P9va2mok0Wg0Go1GS8PW2tp63f/1/b7s0tnZqaamJgWDQWtbbm6ugsGgGhoarujf0dGhaDRqNcNDdgEASFtDhw69bp9+Lz4++OADdXV1yev1xm33er0KhUJX9K+rq5PH47Ga3+/v75AAAIBNEjllwvGrXVasWKFIJGK11tZWp0MCAAADKK+/X3D48OEaNGiQwuFw3PZwOCyfz3dFf5fLJZfL1d9hAACAFNXvMx/5+fmqqKhQfX29tS0Wi6m+vl6BQKC/3w4AAKSZfp/5kKTa2lpVVVVp0qRJmjJlitauXauLFy/q4YcfHoi3AwAAaWRAio8HH3xQ//nPf7Rq1SqFQiF94Qtf0M6dO684CRUAAGSfAbnJ2I2IRqPyeDxOhwEAAPogEonI7Xb32sfxq10AAEB2ofgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ovgAAAC2ynM6ACTvf//7X8J9a2pqrK/Xrl2b8M8NGzYsiYgwEOwY554Yd/v1dZylvo8142w/u8Y5XcaWmQ8AAGArig8AAGCrpIuP/fv36/7771dZWZlycnK0bdu2uP3GGK1atUojRozQkCFDFAwGdfLkyf6KFwAApLkcY4xJ5gd27Nih1157TRUVFZozZ462bt2q2bNnW/t/+tOfqq6uTr///e9VXl6ulStX6tixYzpx4oQKCgqu+/rRaFQejyfpXyRVdF/X67n2lsyaX6pJl3VEu1RVVV1z342cc+E0xjlxHM+Zo+dY9te5NU5IhbGNRCJyu9299kn6hNOZM2dq5syZV91njNHatWv1gx/8QLNmzZIk/eEPf5DX69W2bds0b968ZN8OAABkmH495+PUqVMKhUIKBoPWNo/Ho8rKSjU0NFz1Zzo6OhSNRuMaAADIXP1afIRCIUmS1+uN2+71eq19PdXV1cnj8Vht1KhR/RkSAABIMY7f52PFihWqra21vo9GoylfgCS61uvEmnAqrPdlilRe02ecB07383mcXutnnAd
Oqp23lehYp/LnUjL6debD5/NJksLhcNz2cDhs7evJ5XLJ7XbHNQAAkLn6tfgoLy+Xz+dTfX29tS0ajaqxsVGBQKA/3woAAKSppJddLly4oHfeecf6/tSpU2publZxcbH8fr9qamr04x//WGPGjLEutS0rK4u7HBfX130KLlOm2XAlxjn1DMSUeyZddp8p0nWcM2UpLuni48iRI/ryl79sff/J+RpVVVXauHGjHn/8cV28eFGPPPKI2tradOedd2rnzp0J3eMDAABkvqSLj7vvvlu93ZcsJydHTz/9tJ5++ukbCgwAAGQmnu0CAABslfTt1QdaJt1ePRmZso6XLRjnzNXb2CZ6jg7jnN6SOb4Z6yslcnt1Zj4AAICtKD4AAICtKD4AAICtHL+9eqbpbf2Pa/szB+OcuRJdw+feHZmL43vgMfMBAABsRfEBAABsxbJLiug5lcflW5mJcc4OjHP26D7WjHPimPkAAAC2ovgAAAC2ovgAAAC24pwPG/H49Ozg9CWYnG/gDI7v7OD08Z0pmPkAAAC2ovgAAAC2YtklRfU2lTd79mzr63379tkQDQbKQDw9k2WW5PTXMlVvY8lTUrMD45w4Zj4AAICtKD4AAICtKD4AAICtcowxxukguotGo/J4PE6HYbuBulwr29cVU01/jTPj2n/6a53ejksuGffUMxDjnu7jHIlE5Ha7e+3DzAcAALAVxQcAALAVxQcAALAV53ykgRtZU+SeIOmDewSkhkTHYaDGoLdHtHPr/IExUHnt6zGd7uPMOR8AACDlJFV81NXVafLkyRo6dKhKS0s1e/ZstbS0xPW5dOmSqqurVVJSosLCQs2dO1fhcLhfgwYAAOkrqWWXGTNmaN68eZo8ebI+/vhjPfnkkzp+/LhOnDihm2++WZK0ZMkS/fWvf9XGjRvl8Xi0dOlS5ebm6rXXXkvoPVh2SQ5T9dmjt7FmbAeG0zm/3vHNuPcPJ/KcyZ/diSy7JPVsl507d8Z9v3HjRpWWlqqpqUlf+tKXFIlE9Nvf/labNm3SPffcI0nasGGDbrvtNh08eFBTp05N8lcAAACZ5obO+YhEIpKk4uJiSVJTU5MuX76sYDBo9Rk7dqz8fr8aGhqu+hodHR2KRqNxDQAAZK4+Fx+xWEw1NTWaNm2axo0bJ0kKhULKz89XUVFRXF+v16tQKHTV16mrq5PH47HaqFGj+hoSAABIA0ktu3RXXV2t48eP68CBAzcUwIoVK1RbW2t9H41GKUCAJPV2eSb6zolc2nGbdsTjmLFfn4qPpUuXavv27dq/f79Gjhxpbff5fOrs7FRbW1vc7Ec4HJbP57vqa7lcLrlcrr6EAQAA0lBSyy7GGC1dulRbt27Vnj17VF5eHre/oqJCgwcPVn19vbWtpaVFp0+fViAQ6J+IAQBAWktq5qO6ulqbNm3Syy+/rKFDh1rncXg8Hg0ZMkQej0eLFi1SbW2tiouL5Xa79eijjyoQCHClCxzBpYpwwvXuUMnSSv/g+E5fSRUf69evlyTdfffdcds3bNighx56SJL0zDPPKDc3V3PnzlVHR4emT5+u5557rl+CBQAA6S+p4iOR+5EVFBRo3bp1WrduXZ+DAgAAmYtnuwAAAFv1+VLbdJfqTw286667rK+3bdvmXCBpLhXW2quqqqyv//3vf1+zH+Oc3nr72+p+PN+IVPuccloqXArdPYae48wxfW3MfAAAAFtRfAAAAFtRfAAAAFvlmEQuYbFRNBqVx+NxOowB0X3tv6e1a9f2y3uwJuw8xjk79Bzn7ufzDNRaP+PuDI7p5EQiEbnd7l77MPMBAABsRfEBAABslbWX2tqh52VXfZ2ey6TpuEw1EJfwMu6pp/v0O9PtmavnMkt/jDXjHI+ZDwAAYCuKDwAAYCuKDwAAYCvO+egDLrvKDoxzdmCcswPjnFqY+QAAALai+AAAALZi2eUq7HjyKdNzqYGxzkxOPL2YcXYex3P6YOYDAADYiuIDAADYiuI
DAADYinM+/j/WCrODE+cCABgYHM/pi5kPAABgK4oPAABgK4oPAABgqxxjjHE6iO6i0ag8Ho/TYQAAgD6IRCJyu9299mHmAwAA2Cqp4mP9+vWaMGGC3G633G63AoGAduzYYe2/dOmSqqurVVJSosLCQs2dO1fhcLjfgwYAAOkrqeJj5MiRWrNmjZqamnTkyBHdc889mjVrlt544w1J0mOPPaZXXnlFW7Zs0b59+3TmzBnNmTNnQAIHAABpytygYcOGmd/85jemra3NDB482GzZssXa9+abbxpJpqGhIeHXi0QiRhKNRqPRaLQ0bJFI5Lr/6/t8zkdXV5c2b96sixcvKhAIqKmpSZcvX1YwGLT6jB07Vn6/Xw0NDdd8nY6ODkWj0bgGAAAyV9LFx7Fjx1RYWCiXy6XFixdr69atuv322xUKhZSfn6+ioqK4/l6vV6FQ6JqvV1dXJ4/HY7VRo0Yl/UsAAID0kXTx8dnPflbNzc1qbGzUkiVLVFVVpRMnTvQ5gBUrVigSiVittbW1z68FAABSX9LPdsnPz9ett94qSaqoqNDhw4f17LPP6sEHH1RnZ6fa2triZj/C4bB8Pt81X8/lcsnlciUfOQAASEs3fJ+PWCymjo4OVVRUaPDgwaqvr7f2tbS06PTp0woEAjf6NgAAIEMkNfOxYsUKzZw5U36/X+3t7dq0aZNeffVV7dq1Sx6PR4sWLVJtba2Ki4vldrv16KOPKhAIaOrUqQMVPwAASDNJFR/nzp3TwoULdfbsWXk8Hk2YMEG7du3SV77yFUnSM888o9zcXM2dO1cdHR2aPn26nnvuuQEJHAAApCee7QIAAPoNz3YBAAAph+IDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYiuIDAADYKuWKjxR71AwAAEhCIv/HU674aG9vdzoEAADQR4n8H0+5p9rGYjGdOXNGxhj5/X61trZe9+l42SgajWrUqFHk5xrIT+/IT+/IT+/Iz7Vlc26MMWpvb1dZWZlyc3uf28izKaaE5ebmauTIkYpGo5Ikt9uddQOYDPLTO/LTO/LTO/LTO/JzbdmaG4/Hk1C/lFt2AQAAmY3iAwAA2Cpliw+Xy6Uf/vCHcrlcToeSkshP78hP78hP78hP78jPtZGbxKTcCacAACCzpezMBwAAyEwUHwAAwFYUHwAAwFYUHwAAwFYUHwAAwFYpW3ysW7dOo0ePVkFBgSorK3Xo0CGnQ7JdXV2dJk+erKFDh6q0tFSzZ89WS0tLXJ9Lly6purpaJSUlKiws1Ny5cxUOhx2K2Flr1qxRTk6OampqrG3Znp/3339f3/zmN1VSUqIhQ4Zo/PjxOnLkiLXfGKNVq1ZpxIgRGjJkiILBoE6ePOlgxPbp6urSypUrVV5eriFDhugzn/mMfvSjH8U9FCub8rN//37df//9KisrU05OjrZt2xa3P5FcnD9/XgsWLJDb7VZRUZEWLVqkCxcu2PhbDJze8nP58mUtX75c48eP180336yysjItXLhQZ86ciXuNTM5P0kwK2rx5s8nPzze/+93vzBtvvGG+853vmKKiIhMOh50OzVbTp083GzZsMMePHzfNzc3mq1/9qvH7/ebChQtWn8WLF5tRo0aZ+vp6c+TIETN16lRzxx13OBi1Mw4dOmRGjx5tJkyYYJYtW2Ztz+b8nD9/3txyyy3moYceMo2Njebdd981u3btMu+8847VZ82aNcbj8Zht27aZ119/3TzwwAOmvLzcfPTRRw5Gbo/Vq1ebkpISs337dnPq1CmzZcsWU1hYaJ599lmrTzbl529/+5t56qmnzEsvvWQkma1bt8btTyQXM2bMMJ///OfNwYMHzT/+8Q9z6623mvnz59v8mwyM3vLT1tZmgsGgefHFF81bb71lGhoazJQpU0xFRUXca2RyfpKVksXHlClTTHV1tfV9V1eXKSsrM3V1dQ5G5bxz584ZSWb
fvn3GmP/3Bz948GCzZcsWq8+bb75pJJmGhganwrRde3u7GTNmjNm9e7e56667rOIj2/OzfPlyc+edd15zfywWMz6fz/zsZz+ztrW1tRmXy2X+9Kc/2RGio+677z7z7W9/O27bnDlzzIIFC4wx2Z2fnv9cE8nFiRMnjCRz+PBhq8+OHTtMTk6Oef/9922L3Q5XK856OnTokJFk3nvvPWNMduUnESm37NLZ2ammpiYFg0FrW25uroLBoBoaGhyMzHmRSESSVFxcLElqamrS5cuX43I1duxY+f3+rMpVdXW17rvvvrg8SOTnL3/5iyZNmqSvf/3rKi0t1cSJE/XrX//a2n/q1CmFQqG4/Hg8HlVWVmZFfu644w7V19fr7bffliS9/vrrOnDggGbOnCmJ/HSXSC4aGhpUVFSkSZMmWX2CwaByc3PV2Nhoe8xOi0QiysnJUVFRkSTy01PKPdX2gw8+UFdXl7xeb9x2r9ert956y6GonBeLxVRTU6Np06Zp3LhxkqRQKKT8/Hzrj/sTXq9XoVDIgSjtt3nzZv3zn//U4cOHr9iX7fl59913tX79etXW1urJJ5/U4cOH9b3vfU/5+fmqqqqycnC1Yy0b8vPEE08oGo1q7NixGjRokLq6urR69WotWLBAkrI+P90lkotQKKTS0tK4/Xl5eSouLs66fF26dEnLly/X/PnzrSfbkp94KVd84Oqqq6t1/PhxHThwwOlQUkZra6uWLVum3bt3q6CgwOlwUk4sFtOkSZP0k5/8RJI0ceJEHT9+XM8//7yqqqocjs55f/7zn/XCCy9o06ZN+tznPqfm5mbV1NSorKyM/KDPLl++rG984xsyxmj9+vVOh5OyUm7ZZfjw4Ro0aNAVVySEw2H5fD6HonLW0qVLtX37du3du1cjR460tvt8PnV2dqqtrS2uf7bkqqmpSefOndMXv/hF5eXlKS8vT/v27dMvfvEL5eXlyev1ZnV+RowYodtvvz1u22233abTp09LkpWDbD3Wvv/97+uJJ57QvHnzNH78eH3rW9/SY489prq6Oknkp7tEcuHz+XTu3Lm4/R9//LHOnz+fNfn6pPB47733tHv3bmvWQyI/PaVc8ZGfn6+KigrV19db22KxmOrr6xUIBByMzH7GGC1dulRbt27Vnj17VF5eHre/oqJCgwcPjstVS0uLTp8+nRW5uvfee3Xs2DE1NzdbbdKkSVqwYIH1dTbnZ9q0aVdcmv3222/rlltukSSVl5fL5/PF5ScajaqxsTEr8vPhhx8qNzf+I3DQoEGKxWKSyE93ieQiEAiora1NTU1NVp89e/YoFoupsrLS9pjt9knhcfLkSf39739XSUlJ3P5sz88VnD7j9Wo2b95sXC6X2bhxozlx4oR55JFHTFFRkQmFQk6HZqslS5YYj8djXn31VXP27Fmrffjhh1afxYsXG7/fb/bs2WOOHDliAoGACQQCDkbtrO5XuxiT3fk5dOiQycvLM6tXrzYnT540L7zwgrnpppvMH//4R6vPmjVrTFFRkXn55ZfNv/71LzNr1qyMvZS0p6qqKvOpT33KutT2pZdeMsOHDzePP/641Seb8tPe3m6OHj1qjh49aiSZn//85+bo0aPW1RqJ5GLGjBlm4sSJprGx0Rw4cMCMGTMmYy4l7S0/nZ2d5oEHHjAjR440zc3NcZ/XHR0d1mtkcn6SlZLFhzHG/PKXvzR+v9/k5+ebKVOmmIMHDzodku0kXbVt2LDB6vPRRx+Z7373u2bYsGHmpptuMl/72tfM2bNnnQvaYT2Lj2zPzyuvvGLGjRtnXC6XGTt2rPnVr34Vtz8Wi5mVK1car9drXC6Xuffee01LS4tD0dorGo2aZcuWGb/fbwoKCsynP/1p89RTT8X9s8im/Ozdu/eqnzdVVVXGmMRy8d///tfMnz/fFBYWGrfbbR5++GHT3t7uwG/T/3rLz6lTp675eb13717rNTI5P8nKMabb7fwAAAAGWMqd8wEAADI
bxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALAVxQcAALDV/wHgUGBrLgpAzQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "loss: 0.0006, norm: 0.0004, param_norm: 938.1175, vb_loss: 0.0003, ce_loss: 0.2727: 58%|█████▊ | 900/1563 [01:05<00:27, 24.27it/s]" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVn0lEQVR4nO3df2yV5f3/8VdL29Nqe05pWc+ho0e6SYYOcKzQcsRsbp4FmFEYbBPCRlUyAyuO2mRidbDEjZVsyUQXxWzLYMtkuCaCk00JKwhjqS3tqLMiFWMjjXAOOtZzCkqLPdf3j292Phwopaec3ufX85FcCb3vq/d59331nL657vu+7gxjjBEAAIBFMuMdAAAASC8UHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFIUHwAAwFJjVnw8/fTTmjx5snJzc1VVVaXW1taxeikAAJBEMsbi2S7PP/+8VqxYoWeffVZVVVXavHmzGhsb1dXVpZKSkmG/NxQK6eTJkyooKFBGRkasQwMAAGPAGKO+vj6VlpYqM/MqcxtmDFRWVpqamprw14ODg6a0tNQ0NDRc9Xt7enqMJBqNRqPRaEnYenp6rvq3PuanXQYGBtTe3i6v1xvelpmZKa/Xq+bm5sv69/f3KxgMhpvhIbsAACStgoKCq/aJefHx4YcfanBwUE6nM2K70+mUz+e7rH9DQ4McDke4ud3uWIcEAAAsMpJLJuJ+t0t9fb0CgUC49fT0xDskAAAwhrJifcAJEyZo3Lhx8vv9Edv9fr9cLtdl/W02m2w2W6zDAAAACSrmMx85OTmqqKhQU1NTeFsoFFJTU5M8Hk+sXw4AACSZmM98SFJdXZ2qq6s1a9YsVVZWavPmzTp37pzuu+++sXg5AACQRMak+Ljnnnv0wQcfaMOGDfL5fPrCF76gV1555bKLUAEAQPoZk0XGrkUwGJTD4Yh3GAAAYBQCgYDsdvuwfeJ+twsAAEgvFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSFB8AAMBSWfEOANfmv//977D7x48fP+K+V/o+JIbhxo9xTh0jHeer9R3u+xB/YzHOQ31vomLmAwAAWIriAwAAWCrq4uPgwYO66667VFpaqoyMDO3atStivzFGGzZs0MSJE5WXlyev16vjx4/HKl4AAJDkor7m49y5c7rlllt0//33a/HixZft//nPf66nnnpKv//971VeXq7169dr3rx5Onr0qHJzc2MSdCqI5hxeMrwO/k88cp6K54QTXaKPM2LH6rynwzhHXXwsWLBACxYsGHKfMUabN2/Wj370Iy1cuFCS9Ic//EFOp1O7du3S0qVLry1aAACQ9GJ6zUd3d7d8Pp+8Xm94m8PhUFVVlZqbm4f8nv7+fgWDwYgGAABSV0yLD5/PJ0lyOp0R251OZ3jfpRoaGuRwOMKtrKwsliEBAIAEE/d1Purr61VXVxf+OhgMJlUBkmjn5jifb414jzvjbA3GOT3Ee5yl9BvrmM58uFwuSZLf74/Y7vf7w/suZbPZZLfbIxoAAEhdMS0+ysvL5XK51NTUFN4WDAbV0tIij8cTy
5cCAABJKurTLmfPntU777wT/rq7u1sdHR0qKiqS2+1WbW2tfvrTn2rKlCnhW21LS0u1aNGiWMadMK5lGdxojjsWr4H4Y5zTA+OcHhjnkYu6+Ghra9NXvvKV8Nf/u16jurpa27Zt08MPP6xz587pgQceUG9vr2677Ta98sorrPEBAAAkjaL4uP3222WMueL+jIwMPf7443r88cevKTAAAJCaeLYLAACwVIYZbhojDoLBoBwOR7zDiAmWvIYU+XvAOKcuxjk9XPq5zlhfLhAIXPXOVWY+AACApSg+AACApSg+AACApeK+vHoqG6s1QJBcOCecHhjn9MA4xwYzHwAAwFIUHwAAwFKcdrHQxdN1nIIBAKQrZj4AAIClKD4AAIClKD4AAICluOYjTq52Gy63c6UmxhkAmPkAAAAWo/gAAACW4rQLAMQYp9fSx0ifZszvRCRmPgAAgKUoPgAAgKUoPgAAgKUyjDEm3kFcLBgMyuFwxDuMuBvt8uvpfh4x2UQzzoxtahrud4AxT27p+v4OBAKy2+3D9mHmAwAAWIriAwAAWIriAwAAWIp1PhJUNPeLI3ldPM5XG9eRrieA5HK1Ry0geUXz/k43zHwAAABLRVV8NDQ0aPbs2SooKFBJSYkWLVqkrq6uiD7nz59XTU2NiouLlZ+fryVLlsjv98c0aAAAkLyiutV2/vz5Wrp0qWbPnq1PPvlEjz76qDo7O3X06FFdf/31kqTVq1frr3/9q7Zt2yaHw6E1a9YoMzNT//znP0f0Gtxqe21Ywjd1cUtm+uH9nLpS+f08klttr2mdjw8++EAlJSU6cOCAvvSlLykQCOhTn/qUtm/frm9+85uSpGPHjummm25Sc3Oz5syZc9VjUnxcGz6sUlcqf1hhaLyfU1cqv5/HfJ2PQCAgSSoqKpIktbe368KFC/J6veE+U6dOldvtVnNz85DH6O/vVzAYjGgAACB1jbr4CIVCqq2t1dy5czVt2jRJks/nU05OjgoLCyP6Op1O+Xy+IY/T0NAgh8MRbmVlZaMNCQAAJIFR32pbU1Ojzs5OHTp06JoCqK+vV11dXfjrYDBIARIFbs+ExDinkuHe05yGSQ/pMM6jKj7WrFmj3bt36+DBg5o0aVJ4u8vl0sDAgHp7eyNmP/x+v1wu15DHstlsstlsowkDAAAkoahOuxhjtGbNGu3cuVP79u1TeXl5xP6KigplZ2erqakpvK2rq0snTpyQx+OJTcQAACCpRTXzUVNTo+3bt+vFF19UQUFB+DoOh8OhvLw8ORwOrVy5UnV1dSoqKpLdbteDDz4oj8czojtdEHupOF0HpCvez0gVURUfW7ZskSTdfvvtEdu3bt2qe++9V5L0xBNPKDMzU0uWLFF/f7/mzZunZ555JibBAgCA5HdN63yMBdb5iM7VLjjlf0qpY6TPhmDMk1sqr/+A/xPNs16SbdzHfJ0PAACAaPFU2yTA0xDTw7WMc7L9zyjd8Z5OD6Md53R4PzPzAQAALEXxAQAALEXxAQAALMU1HxYabglszg2mjuGWRua6jtTBOKcPPrtjj5kPAABgKYoPAABgKU67xEkqLzCDSCwOlh54T6cHxjk2mPkAAACWovgAAACWovgAAACW4pqPGIvVssmcK0xsjHN6YJzTA+NsPWY+AACApSg+AACApTjtMoThVi4cav9oMD2XGMZi5cJLMdaJjXFOHWO16uyVjonRY+YDAABYiuIDAABYiuIDAABYims+RmC4c4Wc/0sdjHPqYkns9MR7OnEx8wEAACxF8QEAACxF8QEAACyVYYwx8Q7iYsFgUA6HI95hAACAUQgEArLb7cP2YeYDAABYKqriY8uWLZoxY4bsdrvsdrs8Ho9efvnl8P7z58+rpqZGxcXFys/P15IlS+T3+2MeNAAASF5RFR+TJk3Spk2b1N7erra2Nn31q1/VwoUL9eabb0qSHnroIb300ktqbGzUgQMHdPLkSS1evHhMAgcAAEnKXKPx48eb3
/72t6a3t9dkZ2ebxsbG8L633nrLSDLNzc0jPl4gEDCSaDQajUajJWELBAJX/Vs/6ms+BgcHtWPHDp07d04ej0ft7e26cOGCvF5vuM/UqVPldrvV3Nx8xeP09/crGAxGNAAAkLqiLj7eeOMN5efny2azadWqVdq5c6duvvlm+Xw+5eTkqLCwMKK/0+mUz+e74vEaGhrkcDjCraysLOofAgAAJI+oi4/Pfe5z6ujoUEtLi1avXq3q6modPXp01AHU19crEAiEW09Pz6iPBQAAEl/Uz3bJycnRjTfeKEmqqKjQ4cOH9eSTT+qee+7RwMCAent7I2Y//H6/XC7XFY9ns9lks9mijxwAACSla17nIxQKqb+/XxUVFcrOzlZTU1N4X1dXl06cOCGPx3OtLwMAAFJEVDMf9fX1WrBggdxut/r6+rR9+3a9+uqr2rNnjxwOh1auXKm6ujoVFRXJbrfrwQcflMfj0Zw5c8YqfgAAkGSiKj5Onz6tFStW6NSpU3I4HJoxY4b27Nmjr33ta5KkJ554QpmZmVqyZIn6+/s1b948PfPMM2MSOAAASE482wUAAMQMz3YBAAAJh+IDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYiuIDAABYKuGKjwR71AwAAIjCSP6OJ1zx0dfXF+8QAADAKI3k73jCPdU2FArp5MmTMsbI7Xarp6fnqk/HS0fBYFBlZWXk5wrIz/DIz/DIz/DIz5Wlc26MMerr61NpaakyM4ef28iyKKYRy8zM1KRJkxQMBiVJdrs97QYwGuRneORneORneORneOTnytI1Nw6HY0T9Eu60CwAASG0UHwAAwFIJW3zYbDb9+Mc/ls1mi3coCYn8DI/8DI/8DI/8DI/8XBm5GZmEu+AUAACktoSd+QAAAKmJ4gMAAFiK4gMAAFiK4gMAAFiK4gMAAFgqYYuPp59+WpMnT1Zubq6qqqrU2toa75As19DQoNmzZ6ugoEAlJSVatGiRurq6IvqcP39eNTU1Ki4uVn5+vpYsWSK/3x+niONr06ZNysjIUG1tbXhbuufn/fff13e+8x0VFxcrLy9P06dPV1tbW3i/MUYbNmzQxIkTlZeXJ6/Xq+PHj8cxYusMDg5q/fr1Ki8vV15enj772c/qJz/5ScRDsdIpPwcPHtRdd92l0tJSZWRkaNeuXRH7R5KLM2fOaPny5bLb7SosLNTKlSt19uxZC3+KsTNcfi5cuKB169Zp+vTpuv7661VaWqoVK1bo5MmTEcdI5fxEzSSgHTt2mJycHPO73/3OvPnmm+Z73/ueKSwsNH6/P96hWWrevHlm69atprOz03R0dJivf/3rxu12m7Nnz4b7rFq1ypSVlZmmpibT1tZm5syZY2699dY4Rh0fra2tZvLkyWbGjBlm7dq14e3pnJ8zZ86YG264wdx7772mpaXFvPvuu2bPnj3mnXfeCffZtGmTcTgcZteuXeb11183d999tykvLzcff/xxHCO3xsaNG01xcbHZvXu36e7uNo2NjSY/P988+eST4T7plJ+//e1v5rHHHjMvvPCCkWR27twZsX8kuZg/f7655ZZbzGuvvWb+8Y9/mBtvvNEsW7bM4p9kbAyXn97eXuP1es3zzz9vjh07Zpqbm01lZaWpqKiIOEYq5ydaCVl8VFZWmpqamvDXg4ODprS01DQ0NMQxqvg7ffq0kWQOHDhgjPn/v/DZ2dmmsbEx3Oett94ykkxzc3O8wrRcX1+fmTJlitm7d6/58pe/HC4+0j0/69atM7fddtsV94dCIeNyucwvfvGL8Lbe3l5js9nMn/70JytCjKs777zT3H///RHbFi9ebJYvX26MSe/8XPrHdSS5OHr0qJFkDh8+HO7z8ssvm4yMDPP+++9bFrsVhirOLtXa2mokmffee88Yk175GYmEO+0yMDCg9vZ2eb3e8LbMzEx5vV41NzfHMbL4CwQCkqSio
iJJUnt7uy5cuBCRq6lTp8rtdqdVrmpqanTnnXdG5EEiP3/5y180a9Ysfetb31JJSYlmzpyp3/zmN+H93d3d8vl8EflxOByqqqpKi/zceuutampq0ttvvy1Jev3113Xo0CEtWLBAEvm52Ehy0dzcrMLCQs2aNSvcx+v1KjMzUy0tLZbHHG+BQEAZGRkqLCyURH4ulXBPtf3www81ODgop9MZsd3pdOrYsWNxiir+QqGQamtrNXfuXE2bNk2S5PP5lJOTE/7l/h+n0ymfzxeHKK23Y8cO/etf/9Lhw4cv25fu+Xn33Xe1ZcsW1dXV6dFHH9Xhw4f1gx/8QDk5Oaqurg7nYKj3Wjrk55FHHlEwGNTUqVM1btw4DQ4OauPGjVq+fLkkpX1+LjaSXPh8PpWUlETsz8rKUlFRUdrl6/z581q3bp2WLVsWfrIt+YmUcMUHhlZTU6POzk4dOnQo3qEkjJ6eHq1du1Z79+5Vbm5uvMNJOKFQSLNmzdLPfvYzSdLMmTPV2dmpZ599VtXV1XGOLv7+/Oc/67nnntP27dv1+c9/Xh0dHaqtrVVpaSn5wahduHBB3/72t2WM0ZYtW+IdTsJKuNMuEyZM0Lhx4y67I8Hv98vlcsUpqvhas2aNdu/erf3792vSpEnh7S6XSwMDA+rt7Y3ony65am9v1+nTp/XFL35RWVlZysrK0oEDB/TUU08pKytLTqczrfMzceJE3XzzzRHbbrrpJp04cUKSwjlI1/faD3/4Qz3yyCNaunSppk+fru9+97t66KGH1NDQIIn8XGwkuXC5XDp9+nTE/k8++URnzpxJm3z9r/B47733tHfv3vCsh0R+LpVwxUdOTo4qKirU1NQU3hYKhdTU1CSPxxPHyKxnjNGaNWu0c+dO7du3T+Xl5RH7KyoqlJ2dHZGrrq4unThxIi1ydccdd+iNN95QR0dHuM2aNUvLly8P/zud8zN37tzLbs1+++23dcMNN0iSysvL5XK5IvITDAbV0tKSFvn56KOPlJkZ+RE4btw4hUIhSeTnYiPJhcfjUW9vr9rb28N99u3bp1AopKqqKstjttr/Co/jx4/r73//u4qLiyP2p3t+LhPvK16HsmPHDmOz2cy2bdvM0aNHzQMPPGAKCwuNz+eLd2iWWr16tXE4HObVV181p06dCrePPvoo3GfVqlXG7Xabffv2mba2NuPxeIzH44lj1PF18d0uxqR3flpbW01WVpbZuHGjOX78uHnuuefMddddZ/74xz+G+2zatMkUFhaaF1980fz73/82CxcuTNlbSS9VXV1tPv3pT4dvtX3hhRfMhAkTzMMPPxzuk0756evrM0eOHDFHjhwxkswvf/lLc+TIkfDdGiPJxfz5883MmTNNS0uLOXTokJkyZUrK3Eo6XH4GBgbM3XffbSZNmmQ6OjoiPq/7+/vDx0jl/EQrIYsPY4z51a9+Zdxut8nJyTGVlZXmtddei3dIlpM0ZNu6dWu4z8cff2y+//3vm/Hjx5vrrrvOfOMb3zCnTp2KX9Bxdmnxke75eemll8y0adOMzWYzU6dONb/+9a8j9odCIbN+/XrjdDqNzWYzd9xxh+nq6opTtNYKBoNm7dq1xu12m9zcXPOZz3zGPPbYYxF/LNIpP/v37x/y86a6utoYM7Jc/Oc//zHLli0z+fn5xm63m/vuu8/09fXF4aeJveHy093dfcXP6/3794ePkcr5iVaGMRct5wcAADDGEu6aDwAAkNooPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKUoPgAAgKX+H9ZtKdLUYoywAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "loss: 0.0004, norm: 0.0002, param_norm: 938.9938, vb_loss: 0.0002, ce_loss: 0.1951: 77%|███████▋ | 1200/1563 [01:26<00:14, 25.02it/s]" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAACwCAYAAACviAzDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYT0lEQVR4nO3df0xV9/3H8ReIXGyRi2AAmVLZ2sx2/qhDRWoz+4NFXdP6K2s1bqI1a3TYSUlWazud6eJwW1Kxi7XZL+iyOp1JwdWtGoe/5oaoTNpaW2pTpqR6cZ2Fi7aChc/3j317cs9Vr1y8nHsv9/lITnJ+ce6H95tzeeecz+ecOGOMEQAAgEPiw90AAAAQWyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAoyg+AACAo/qs+Ni0aZNGjhyppKQk5efn68iRI331UQAAIIrE9cW7XbZt26aFCxfq5ZdfVn5+vsrLy7V9+3Y1NjYqIyMj4M92d3fr7NmzGjx4sOLi4kLdNAAA0AeMMWpvb1d2drbi429wbcP0gUmTJpni4mJruaury2RnZ5uysrIb/mxzc7ORxMTExMTExBSFU3Nz8w3/14f8tktnZ6fq6+tVWFhorYuPj1dhYaFqa2uv2r+jo0Ner9eaDC/ZBQAgag0ePPiG+4S8+Pj444/V1dWlzMxM2/rMzEx5PJ6r9i8rK5Pb7bamnJycUDcJAAA4pCddJsI+2mXVqlVqa2uzpubm5nA3CQAA9KGEUB9w6NChGjBggFpaWmzrW1palJWVddX+LpdLLpcr1M0AAAARKuRXPhITE5WXl6eamhprXXd3t2pqalRQUBDqjwMAAFEm5Fc+JKm0tFRFRUWaMGGCJk2apPLycl26dEmLFy/ui48DAABRpE+Kj8cee0z/+c9/tGbNGnk8Ht19993atWvXVZ1QAQBA7OmTh4zdDK/XK7fbHe5mAACAXmhra1NKSkrAfcI+2gUAAMQWig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOCohHA3AKFVUVFhW66srLTmFy1aZNs2a9as6x5n7dq1tuWNGzfeZMsQSuQ5NgTKs2TPNXmObr657m2eJXuuIznPXPkAAACOovgAAACOCrr4OHjwoB5++GFlZ2crLi5O1dXVtu3GGK1Zs0bDhg3ToEGDVFhYqFOnToWqvQAAIMoF3efj0qVLGjdunB5//HHNmTPnqu0///nP9eKLL+qVV15Rbm6uVq9erWnTpunkyZNKSkoKSaOjlf/9W183uo/XU/v377ct9/ReYUNDQ8Bl9Jx/nu+77z5rPjU1NSSf0ds8S/bckufeC/f5fKPPIc+9Fyi3/voi18Hk2V+05Dro4mPGjBmaMWPGNbcZY1ReXq4f/ehHmjlzpiTp97//vTIzM1VdXa158+bdXGsBAEDUC2mfj6amJnk8HhUWFlrr3G638vPzVVtbe82f6ejokNfrtU0AAKD/Cmnx4fF4JEmZmZm29ZmZmdY2f2VlZXK73dY0YsSIUDYJAABEmDhjjOn1D8fFqaqqyrof9c9//lNTpkzR2bNnNWzYMG
u/Rx99VHFxcdq2bdtVx+jo6FBHR4e17PV6o6oAaWpqsi2H6p5+b/nfKwxkx44d1vxtt91m23b69OlQNalf6C95luy5Js/SuHHjbMu+nejJc//le06HO89Sz3MdKM9SZOS6ra1NKSkpAfcJ6ZWPrKwsSVJLS4ttfUtLi7XNn8vlUkpKim0CAAD9V0iLj9zcXGVlZammpsZa5/V6VVdXp4KCglB+FAAAiFJB33a5ePGiPvjgA0nS+PHj9cILL+j+++9XWlqacnJy9LOf/Uzr16+3DbV96623ejzU1uv1yu129+63CQP/S16hGubke5zW1lbbNt+hm/6GDBkSks+H3Rejt77g//jj3iLP4Td16lTb8r///W9rvi/OZ8mea/Ic+T755JMe7Xf33XfblsvLy635QHmW+leue3LbJeihtseOHdP9999vLZeWlkqSioqKVFlZqaefflqXLl3SE088odbWVt17773atWtXzD/jAwAA/E/Qxcd9992nQBdL4uLi9Pzzz+v555+/qYYBAID+iXe7AAAARwV95QN2/sOafPsC+A/fCubx5r63tvz59jMpKSm5URMRAv7D2/zvz/rmJFA/gd7mWSLXfeXAgQPX3eZ/zvb2dQXkObr5nu/+/T98c+3//2D27NnWPHm248oHAABwFMUHAABwFMUHAABw1E09Xr0vRNtzPoJRVVVlzfuP+fZ/tK5/HwNED988S/Zck+f+gzwD1+b449UBAABuhOIDAAA4itsuYeI/XGv//v22Zd8hWug//C/Vk+f+w/ec5nxGLOO2CwAAiDgUHwAAwFEUHwAAwFE8Xj1M/O8J+w+99X3Nd6DHPyPy+fbz8L/3T577D99zOtD5LJHraBao3xZ57jmufAAAAEdRfAAAAEcx1DZM/N9w6P8WTN9LuK2trdfd9sorr4S4ZQi1QG+8DZRnf4sXLw5hqxBqPc2zZM/12rVrbdv834yKyFJRUWFb9n27MXn+H4baAgCAiEPxAQAAHEXxAQAAHEWfjwgxc+ZM23JlZWWPfq6kpMS2XF1dbVtua2u7iVYhnAI9gp/HdUcX/1wOGTLkutvIc/QKJs++2/ob+nwAAICIQ/EBAAAcRfEBAAAcRZ+PCOX7zIDy8nLbNv9HN/vyHXMu8Xjf/sT3nrH//WLfc4Z+PkB4+D+/qbfP8gjVccKFPh8AACDiBFV8lJWVaeLEiRo8eLAyMjI0a9YsNTY22va5fPmyiouLlZ6eruTkZM2dO1ctLS0hbTQAAIheQd12mT59uubNm6eJEyfq888/17PPPqsTJ07o5MmTuvXWWyVJy5Yt01/+8hdVVlbK7XZr+fLlio+P1z/+8Y8efQa3Xa7mHw/fy+r+w7f8MWyvf4qlYXv9EbfJ0J/15LZLQjAH3LVrl225srJSGRkZqq+v1ze+8Q21tbXpt7/9rbZs2aIHHnhA0v+eg3/nnXfq8OHDmjx5cpC/AgAA6G9uqs/HFxV7WlqaJKm+vl5XrlxRYWGhtc+oUaOUk5Oj2traax6jo6NDXq/XNgEAgP6r18VHd3e3SkpKNGXKFI0ePVqS5PF4lJiYqNTUVNu+mZmZ8ng81zxOWVmZ3G63NY0YMaK3TQIAAFEgqNsuvoqLi3XixAkdOnTophqwatUqlZaWWster5cCxM/N3BP2HZY7depU2zaG4UaXQP17fLcx3DryBTqnb9SPyxd9faJXrD+KvVfFx/Lly7Vz504dPHhQw4cPt9ZnZWWps7NTra2ttqsfLS0tysrKuuaxXC6XXC5Xb5oBAACiUFC3XYwxWr58uaqqqrR3717l5ubatufl5WngwIGqqamx1jU2NurMmTMqKCgITYsBAEBUC+rKR3FxsbZs2aIdO3Zo8ODBVj8Ot9utQYMGye12a8mSJSotLVVaWppSUlL05JNPqqCggJEuIVRVVdXjfRctWmTNc/k9ugST54aGBmuePEeX3p7PiC43yrPvrZZYyHNQxcfmzZslXf1474qKCitYGz
ZsUHx8vObOnauOjg5NmzZNL730UkgaCwAAol9QxUdPnkeWlJSkTZs2adOmTb1uFAAA6L94twsAAHBUr4fa4ub4v7Vw5MiRtuWSkhJrPtBbbG/E997hjh07en0chIb/cOdQ5bm1tbXXP4vQC5Rnqfe55hyOPL65rqystG3zf+ZVT/n3+eiPeefKBwAAcBTFBwAAcBTFBwAAcBR9PhzU1NRkzfvf06uuru7VMX2f7yBJa9eutS3zzAfn+eZZsue6t3n2xyPUw6+v8uzbbyBUfy+4OU5/d8fC+cyVDwAA4CiKDwAA4Kg405MnhznI6/XK7XaHuxm95j/EzleoLs+Vl5db8/1xCFY0cDrPErkOh77Is2TPNXkOv0B5lkJzTsdSntva2pSSkhJwH658AAAAR1F8AAAAR1F8AAAARzHUthcqKiqsef/H6dKvo/8gz7HBN89SaIa60n8nMmzYsMGa93/UeaiGNMdqv46bxZUPAADgKIoPAADgKIba/r+qqipr3v8Nof6X627m7aNf2L9/v2159uzZN31M3Jj/Uyl989AXefb/DPLsjHHjxtmWfd8qS577D9/vbSnwd3df5Fki19fCUFsAABBxKD4AAICjKD4AAICj6PNxDZ988kmP9/V/u6jvvWV/3BuMPD3NdTB59n/rZVtbW5CtQqj19pzmfI4ufZFnyX5Ocz7fGH0+AABAxKH4AAAAjqL4AAAAjqLPBwAACBn6fAAAgIgTVPGxefNmjR07VikpKUpJSVFBQYHeeOMNa/vly5dVXFys9PR0JScna+7cuWppaQl5owEAQPQKqvgYPny41q9fr/r6eh07dkwPPPCAZs6cqXfeeUeS9NRTT+n111/X9u3bdeDAAZ09e1Zz5szpk4YDAIAoZW7SkCFDzG9+8xvT2tpqBg4caLZv325te/fdd40kU1tb2+PjtbW1GUlMTExMTExMUTi1tbXd8H99r/t8dHV1aevWrbp06ZIKCgpUX1+vK1euqLCw0Npn1KhRysnJUW1t7XWP09HRIa/Xa5sAAED/FXTx8fbbbys5OVkul0tLly5VVVWV7rrrLnk8HiUmJl71xsjMzEx5PJ7rHq+srExut9uaRowYEfQvAQAAokfQxcdXv/pVNTQ0qK6uTsuWLVNRUZFOnjzZ6wasWrVKbW1t1tTc3NzrYwEAgMiXEOwPJCYm6vbbb5ck5eXl6ejRo9q4caMee+wxdXZ2qrW11Xb1o6WlRVlZWdc9nsvlksvlCr7lAAAgKt30cz66u7vV0dGhvLw8DRw4UDU1Nda2xsZGnTlzRgUFBTf7MQAAoJ8I6srHqlWrNGPGDOXk5Ki9vV1btmzR/v37tXv3brndbi1ZskSlpaVKS0tTSkqKnnzySRUUFGjy5Ml91X4AABBlgio+zp8/r4ULF+rcuXNyu90aO3asdu/erW9+85uSpA0bNig+Pl5z585VR0eHpk2bppdeeqlPGg4AAKIT73YBAAAhw7tdAABAxKH4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjqL4AAAAjoq44iPCXjUDAACC0JP/4xFXfLS3t4e7CQAAoJd68n884t5q293drbNnz8oYo5ycHDU3N9/w7XixyOv1asSIEcTnOohPYMQnMOITGPG5vliOjTFG7e3tys7OVnx84GsbCQ61qcfi4+M1fPhweb1eSVJKSkrMJTAYxCcw4hMY8QmM+ARGfK4vVmPjdrt7tF/E3XYBAAD9G8UHAABwVMQWHy6XSz/+8Y/lcrnC3ZSIRHwCIz6BEZ/AiE9gxOf6iE3PRFyHUwAA0L9F7JUPAADQP1F8AAAAR1F8AAAAR1F8AAAAR1F8AAAAR0Vs8bFp0yaNHDlSSUlJys/P15EjR8LdJMeVlZVp4sSJGjx4sDIyMjRr1iw1Njba9rl8+bKKi4uVnp6u5ORkzZ07Vy0tLWFqcXitX79ecXFxKikpsdbFen
w++ugjfec731F6eroGDRqkMWPG6NixY9Z2Y4zWrFmjYcOGadCgQSosLNSpU6fC2GLndHV1afXq1crNzdWgQYP0la98RT/5yU9sL8WKpfgcPHhQDz/8sLKzsxUXF6fq6mrb9p7E4sKFC1qwYIFSUlKUmpqqJUuW6OLFiw7+Fn0nUHyuXLmilStXasyYMbr11luVnZ2thQsX6uzZs7Zj9Of4BM1EoK1bt5rExETzu9/9zrzzzjvme9/7nklNTTUtLS3hbpqjpk2bZioqKsyJEydMQ0OD+da3vmVycnLMxYsXrX2WLl1qRowYYWpqasyxY8fM5MmTzT333BPGVofHkSNHzMiRI83YsWPNihUrrPWxHJ8LFy6Y2267zSxatMjU1dWZDz/80Ozevdt88MEH1j7r1683brfbVFdXmzfffNM88sgjJjc313z22WdhbLkz1q1bZ9LT083OnTtNU1OT2b59u0lOTjYbN2609oml+Pz1r381zz33nHnttdeMJFNVVWXb3pNYTJ8+3YwbN84cPnzY/P3vfze33367mT9/vsO/Sd8IFJ/W1lZTWFhotm3bZt577z1TW1trJk2aZPLy8mzH6M/xCVZEFh+TJk0yxcXF1nJXV5fJzs42ZWVlYWxV+J0/f95IMgcOHDDG/O8PfuDAgWb79u3WPu+++66RZGpra8PVTMe1t7ebO+64w+zZs8dMnTrVKj5iPT4rV640995773W3d3d3m6ysLPOLX/zCWtfa2mpcLpf54x//6EQTw+qhhx4yjz/+uG3dnDlzzIIFC4wxsR0f/3+uPYnFyZMnjSRz9OhRa5833njDxMXFmY8++sixtjvhWsWZvyNHjhhJ5vTp08aY2IpPT0TcbZfOzk7V19ersLDQWhcfH6/CwkLV1taGsWXh19bWJklKS0uTJNXX1+vKlSu2WI0aNUo5OTkxFavi4mI99NBDtjhIxOfPf/6zJkyYoG9/+9vKyMjQ+PHj9etf/9ra3tTUJI/HY4uP2+1Wfn5+TMTnnnvuUU1Njd5//31J0ptvvqlDhw5pxowZkoiPr57Eora2VqmpqZowYYK1T2FhoeLj41VXV+d4m8Otra1NcXFxSk1NlUR8/EXcW20//vhjdXV1KTMz07Y+MzNT7733XphaFX7d3d0qKSnRlClTNHr0aEmSx+NRYmKi9cf9hczMTHk8njC00nlbt27Vv/71Lx09evSqbbEenw8//FCbN29WaWmpnn32WR09elQ/+MEPlJiYqKKiIisG1zrXYiE+zzzzjLxer0aNGqUBAwaoq6tL69at04IFCyQp5uPjqyex8Hg8ysjIsG1PSEhQWlpazMXr8uXLWrlypebPn2+92Zb42EVc8YFrKy4u1okTJ3To0KFwNyViNDc3a8WKFdqzZ4+SkpLC3ZyI093drQkTJuinP/2pJGn8+PE6ceKEXn75ZRUVFYW5deH3pz/9Sa+++qq2bNmir33ta2poaFBJSYmys7OJD3rtypUrevTRR2WM0ebNm8PdnIgVcbddhg4dqgEDBlw1IqGlpUVZWVlhalV4LV++XDt37tS+ffs0fPhwa31WVpY6OzvV2tpq2z9WYlVfX6/z58/r61//uhISEpSQkKADBw7oxRdfVEJCgjIzM2M6PsOGDdNdd91lW3fnnXfqzJkzkmTFIFbPtR/+8Id65plnNG/ePI0ZM0bf/e539dRTT6msrEwS8fHVk1hkZWXp/Pnztu2ff/65Lly4EDPx+qLwOH36tPbs2WNd9ZCIj7+IKz4SExOVl5enmpoaa113d7dqampUUFAQxpY5zxij5cuXq6qqSnv37lVubq5te15engYOHGiLVWNjo86cORMTsXrwwQf19ttvq6GhwZomTJigBQsWWPOxHJ8pU6ZcNTT7/fff12233SZJys3NVVZWli0+Xq9XdXV1MRGfTz/9VPHx9q/AAQMGqLu7WxLx8dWTWBQUFKi1tVX19fXWPnv37lV3d7fy8/Mdb7PTvig8Tp
06pb/97W9KT0+3bY/1+Fwl3D1er2Xr1q3G5XKZyspKc/LkSfPEE0+Y1NRU4/F4wt00Ry1btsy43W6zf/9+c+7cOWv69NNPrX2WLl1qcnJyzN69e82xY8dMQUGBKSgoCGOrw8t3tIsxsR2fI0eOmISEBLNu3Tpz6tQp8+qrr5pbbrnF/OEPf7D2Wb9+vUlNTTU7duwwb731lpk5c2a/HUrqr6ioyHzpS1+yhtq+9tprZujQoebpp5+29oml+LS3t5vjx4+b48ePG0nmhRdeMMePH7dGa/QkFtOnTzfjx483dXV15tChQ+aOO+7oN0NJA8Wns7PTPPLII2b48OGmoaHB9n3d0dFhHaM/xydYEVl8GGPML3/5S5OTk2MSExPNpEmTzOHDh8PdJMdJuuZUUVFh7fPZZ5+Z73//+2bIkCHmlltuMbNnzzbnzp0LX6PDzL/4iPX4vP7662b06NHG5XKZUaNGmV/96le27d3d3Wb16tUmMzPTuFwu8+CDD5rGxsYwtdZZXq/XrFixwuTk5JikpCTz5S9/2Tz33HO2fxaxFJ99+/Zd8/umqKjIGNOzWPz3v/818+fPN8nJySYlJcUsXrzYtLe3h+G3Cb1A8Wlqarru9/W+ffusY/Tn+AQrzhifx/kBAAD0sYjr8wEAAPo3ig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOAoig8AAOCo/wNZvNGki+7vXAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "loss: 0.0005, norm: 0.0005, param_norm: 939.0556, vb_loss: 0.0001, ce_loss: 0.1722: 81%|████████ | 1260/1563 [01:37<00:23, 12.98it/s]\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[4], line 31\u001b[0m\n\u001b[1;32m 28\u001b[0m norm \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mclip_grad_norm_(d3pm\u001b[38;5;241m.\u001b[39mx0_model\u001b[38;5;241m.\u001b[39mparameters(), \u001b[38;5;241m0.01\u001b[39m)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 31\u001b[0m param_norm \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m([\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mp\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m d3pm\u001b[38;5;241m.\u001b[39mx0_model\u001b[38;5;241m.\u001b[39mparameters()])\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m loss_ema \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 33\u001b[0m loss_ema \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mitem()\n", "File \u001b[0;32m~/discrete-diffusion/cu122py310/lib/python3.12/site-packages/torch/functional.py:1610\u001b[0m, in \u001b[0;36mnorm\u001b[0;34m(input, p, dim, keepdim, out, dtype)\u001b[0m\n\u001b[1;32m 1608\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m p \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfro\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (dim 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(dim, (\u001b[38;5;28mint\u001b[39m, torch\u001b[38;5;241m.\u001b[39mSymInt)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(dim) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m 1609\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m out \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1610\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinalg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvector_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1611\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1612\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mvector_norm(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m2\u001b[39m, _dim, keepdim, dtype\u001b[38;5;241m=\u001b[39mdtype, out\u001b[38;5;241m=\u001b[39mout)\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "d3pm.train()\n", "\n", "n_epoch = 10\n", "device = 'cuda'\n", "from torchvision.utils import make_grid\n", "from matplotlib import pyplot as plt\n", "global_step = 0\n", "for i in range(n_epoch):\n", " d3pm.train()\n", " pbar = tqdm(dataloader)\n", " loss_ema = None\n", " for x, _ in pbar:\n", " #print(x)\n", " optim.zero_grad()\n", " x = x.to(device)\n", " # discritize x to 10 bins\n", " x = (x * N).long().clamp(0, N - 1)\n", " 
\n", " loss, info = d3pm(x)\n", " \n", " # if loss.item() > 1000 or torch.isnan(loss):\n", " # print(f\"loss is too high, skipping, {loss.item()}\")\n", " # loss.backward()\n", " # optim.zero_grad()\n", " # continue\n", " #print(loss.item())\n", " loss.backward()\n", " norm = torch.nn.utils.clip_grad_norm_(d3pm.x0_model.parameters(), 0.01)\n", "\n", " with torch.no_grad():\n", " param_norm = sum([torch.norm(p) for p in d3pm.x0_model.parameters()])\n", " if loss_ema is None:\n", " loss_ema = loss.item()\n", " else:\n", " loss_ema = 0.99 * loss_ema + 0.01 * loss.item()\n", " pbar.set_description(f\"loss: {loss_ema:.4f}, norm: {norm:.4f}, param_norm: {param_norm:.4f}, vb_loss: {info['vb_loss']:.4f}, ce_loss: {info['ce_loss']:.4f}\")\n", " optim.step()\n", " global_step += 1\n", "\n", " if global_step % 300 == 1:\n", " d3pm.eval()\n", "\n", " with torch.no_grad():\n", " x = d3pm.sample(torch.randint(0, N, (4, 1, 32, 32)).cuda())\n", " x_as_image = make_grid(x.float() / N, nrow=4)\n", " plt.figure()\n", " plt.imshow(x_as_image.permute(1, 2, 0).cpu().numpy())\n", " plt.show()\n", "\n", " d3pm.train()\n", " " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeuElEQVR4nO3dfWzV5f3/8Ve56RGlPbXc9EYKFlCYQHFjUDsVUSqlZoSbLsGbbDCZBFaI0HlXoyJufsvwHoNo4gJxEXEsAsNEUIotcRYclQ5R1wDpBEZbJlvPKcUW1l6/P5adn5W782lPefeU5yO5Es7n8+513p9csS8/55xeJ8Y55wQAwEXWzboBAMCliQACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiR7WDXxXS0uLjh49qri4OMXExFi3AwDwyDmn+vp6paamqlu3c9/ndLoAOnr0qNLS0qzbAAC00+HDhzVgwIBznu+wl+BWrlypq6++WpdddpkyMzP1ySefhPVzcXFxHdUSAOAiutDv8w4JoLffflsFBQVasmSJPv30U40ePVo5OTk6duzYBX+Wl90AoGu44O9z1wHGjRvn8vPzQ4+bm5tdamqqKyoquuDPBgIBJ4nBYDAYUT4CgcB5f99H/A7o1KlTKi8vV3Z2duhYt27dlJ2drbKysjPqm5qaFAwGWw0AQNcX8QD6+uuv1dzcrKSkpFbHk5KSVFNTc0Z9UVGR/H5/aPABBAC4NJj/HVBhYaECgUBoHD582LolAMBFEPGPYfft21fdu3dXbW1tq+O1tbVKTk4+o97n88nn80W6DQBAJxfxO6DY2FiNGTNGxcXFoWMtLS0qLi5WVlZWpJ8OABClOuQPUQsKCjRr1iz98Ic/1Lhx4/Tiiy+qoaFBP//5zzvi6QAAUahDAmjmzJn65z//qSeeeEI1NTW6/vrrtWXLljM+mAAAuHTFOOecdRPfFgwG5ff7rdsAALRTIBBQfHz8Oc+bfwoOAHBpIoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAICJiAfQk08+qZiYmFZj+PDhkX4aAECU69ERk44YMULbtm37/0/So0OeBgAQxTokGXr06KHk5OSOmBoA0EV0yHtA+/fvV2pqqgYPHqx77rlHhw4dOmdtU1OTgsFgqwEA6PoiHkCZmZlas2aNtmzZolWrVqmqqko333yz6uvrz1pfVFQkv98fGmlpaZFuCQDQCcU451xHPkFdXZ0GDRqk559/XnPmzDnjfFNTk5qamkKPg8EgIQQAXUAgEFB8fPw5z3f4pwMSEhJ07bXX6sCBA2c97/P55PP5OroNAEAn0+F/B3TixAkdPHhQKSkpHf1UAIAoEvEAeuCBB1RaWqq///3v+vjjjzV9+nR1795dd911V6SfCgAQxSL+EtyRI0d011136fjx4+rXr59uuukm7dy5U/369Yv0UwEAoliHfwjBq2AwKL/fb90GAKCdLvQhBPaCAwCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJjr86xiAaNW9e/ewazvT9lELFizwVH/55ZeHXTts2DBPc+fn54dd++yzz3qa28sGx42NjZ7mXrZsmaf6pUuXeqrHf3EHBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATLAVDzrcwIEDw66NjY31NPePfvSjsGtvuukmT3M
nJCSEXZuXl+dp7mh15MgRT/UrVqwIu3b69Ome5q6vrw+79q9//aunuUtLSz3Vo224AwIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACAiRjnnLNu4tuCwaD8fr91GziP73//+57qi4uLw65l7S++lpaWsGvvvfdeT3M3NDR4bSdsR48eDbv23//+t6e5KysrvbaDswgEAoqPjz/nee6AAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCih3UDiD5fffWVp/rjx4+HXXup7AW3a9cuT/V1dXVh1956662e5j516lTYtb///e89zQ2cD3dAAAATngNox44dmjJlilJTUxUTE6ONGze2Ou+c0xNPPKGUlBT16tVL2dnZ2r9/f6T6BQB0EZ4DqKGhQaNHj9bKlSvPen758uVasWKFXn31Ve3atUtXXHGFcnJy1NjY2O5mAQBdh+f3gHJzc5Wbm3vWc845vfjii3rsscc0depUSdIbb7yhpKQkbdy4UXfeeWf7ugUAdBkRfQ+oqqpKNTU1ys7ODh3z+/3KzMxUWVnZWX+mqalJwWCw1QAAdH0RDaCamhpJUlJSUqvjSUlJoXPfVVRUJL/fHxppaWmRbAkA0EmZfwqusLBQgUAgNA4fPmzdEgDgIohoACUnJ0uSamtrWx2vra0Nnfsun8+n+Pj4VgMA0PVFNIDS09OVnJys4uLi0LFgMKhdu3YpKysrkk8FAIhynj8Fd+LECR04cCD0uKqqShUVFUpMTNTAgQO1aNEi/eY3v9E111yj9PR0Pf7440pNTdW0adMi2TcAIMp5DqDdu3e32uqjoKBAkjRr1iytWbNGDz30kBoaGjR37lzV1dXppptu0pYtW3TZZZdFrmuY+te//uWp/sEHHwy79sc//rGnuffs2RN27YoVKzzN7UVFRYWn+ttvv91TfUNDQ9i1I0aM8DT3/fff76keiBTPATRhwgQ55855PiYmRk899ZSeeuqpdjUGAOjazD8FBwC4NBFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMx7nz76hgIBoPy+/3WbcCI16/jqK+vD7v2tdde8zT3nDlzwq796U9/6mnutWvXeqoHolEgEDjvf9PcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABM9rBsAvi0YDHbY3IFAoMPm/sUvfuGpft26dZ7qW1paPNUD0YA7IACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYiHHOOesmvi0YDMrv91u3gS7oiiuu8FS/efPmsGtvueUWT3Pn5uZ6qn///fc91QOdQSAQUHx8/DnPcwcEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMsBUPcA5DhgwJu/bTTz/1NHddXZ2n+g8//DDs2t27d3uae+XKlWHXdrJfF+jk2IoHANApEUAAABOeA2jHjh2aMmWKUlNTFRMTo40bN7Y6P3v2bMXExLQakydPjlS/AIAuwnMANTQ0aPTo0ed93Xjy5Mmqrq4OjbfeeqtdTQIAup4eXn8gNzf3gt9l4vP5lJyc3OamAABdX4e8B1RSUqL+/ftr2LBhmj9/vo4fP37O2qamJgWDwVYDAND1RTyAJk+erDfeeEPFxcX67W9/q9LSUuXm5qq5ufms9UVFRfL7/aGRlpYW6ZYAAJ2Q55fgLuTOO+8M/XvUqFHKyMjQkCFDVFJSookTJ55RX1hYqIKCgtDjYDBICAHAJaDDP4Y9ePBg9e3bVwcOHDjreZ/Pp/j4+FYDAND1dXgAHTlyRMePH1dKSkpHPxUAIIp4fgnuxIkTre5mqqqqVFFRocTERCUmJmrp0qXKy8t
TcnKyDh48qIceekhDhw5VTk5ORBsHAEQ3z3vBlZSU6NZbbz3j+KxZs7Rq1SpNmzZNe/bsUV1dnVJTUzVp0iT9+te/VlJSUljzsxccotH06dM91a9evdpTfVxcnKd6Lx599NGwa9944w1Pc1dXV3ttB13IhfaC83wHNGHChPNuSLh161avUwIALkHsBQcAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEx43guuo7EXHC4Fo0aN8lT/3HPPhV17tu/dipTXXnvNU/3TTz8ddu0//vEPr+2gk7vQXnDcAQEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABNsxQNEgYSEhLBrp0yZ4mnu1atXh10bExPjae7t27eHXXv77bd7mhudH1vxAAA6JQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYC844BLX1NQUdm2PHj08zf2f//wn7NqcnBxPc5eUlHiqx8XHXnAAgE6JAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY8LavBoCIyMjI8FT/k5/8JOzasWPHeprb6/Y6XnzxxRdh1+7YsaPD+kDnxB0QAMCEpwAqKirS2LFjFRcXp/79+2vatGmqrKxsVdPY2Kj8/Hz16dNHvXv3Vl5enmprayPaNAAg+nkKoNLSUuXn52vnzp364IMPdPr0aU2aNEkNDQ2hmsWLF2vz5s1av369SktLdfToUc2YMSPijQMAopunF3+3bNnS6vGaNWvUv39/lZeXa/z48QoEAvrd736ntWvX6rbbbpMkrV69Wt/73ve0c+dO3XDDDZHrHAAQ1dr1HlAgEJAkJSYmSpLKy8t1+vRpZWdnh2qGDx+ugQMHqqys7KxzNDU1KRgMthoAgK6vzQHU0tKiRYsW6cYbb9TIkSMlSTU1NYqNjVVCQkKr2qSkJNXU1Jx1nqKiIvn9/tBIS0tra0sAgCjS5gDKz8/Xvn37tG7dunY1UFhYqEAgEBqHDx9u13wAgOjQpj8AWLBggd59913t2LFDAwYMCB1PTk7WqVOnVFdX1+ouqLa2VsnJyWedy+fzyefztaUNAEAU83QH5JzTggULtGHDBm3fvl3p6emtzo8ZM0Y9e/ZUcXFx6FhlZaUOHTqkrKysyHQMAOgSPN0B5efna+3atdq0aZPi4uJC7+v4/X716tVLfr9fc+bMUUFBgRITExUfH6+FCxcqKyuLT8ABAFrxFECrVq2SJE2YMKHV8dWrV2v27NmSpBdeeEHdunVTXl6empqalJOTo1deeSUizQIAuo4Y55yzbuLbgsGg/H6/dRuAhg0bFnbtwoULPc09ffp0T/Xneg/1YmtubvZUv23btrBr77jjDq/toJMLBAKKj48/53n2ggMAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACba9HUMQGfhZYuau+++29Pc+fn5YddeffXVnubuTHbv3h127dNPP+1p7j/96U9e28ElhDsgAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgLzh0uKSkpLBrR4wY4Wnul19+Oeza4cOHe5q7M9m1a1fYtc8884ynuTdt2hR2bUtLi6e5gfPhDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgKx4oMTHRU/1rr73mqf76668Pu3bw4MGe5u4sPv74Y0/1zz33nKf6rVu3hl37zTffeJobsMIdEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMsBdclMjMzPRU/+CDD4ZdO27cOE9zX3XVVZ7qOwuve6S99NJLYdf+3//
9n6e5GxoaPNUDXRF3QAAAE54CqKioSGPHjlVcXJz69++vadOmqbKyslXNhAkTFBMT02rMmzcvok0DAKKfpwAqLS1Vfn6+du7cqQ8++ECnT5/WpEmTzng54b777lN1dXVoLF++PKJNAwCin6f3gLZs2dLq8Zo1a9S/f3+Vl5dr/PjxoeOXX365kpOTI9MhAKBLatd7QIFAQNKZX2j25ptvqm/fvho5cqQKCwt18uTJc87R1NSkYDDYagAAur42fwqupaVFixYt0o033qiRI0eGjt99990aNGiQUlNTtXfvXj388MOqrKzUO++8c9Z5ioqKtHTp0ra2AQCIUm0OoPz8fO3bt08fffRRq+Nz584N/XvUqFFKSUnRxIkTdfDgQQ0ZMuSMeQoLC1VQUBB6HAwGlZaW1ta2AABRok0BtGDBAr377rvasWOHBgwYcN7a//39yoEDB84aQD6fTz6fry1tAACimKcAcs5p4cKF2rBhg0pKSpSenn7Bn6moqJAkpaSktKlBAEDX5CmA8vPztXbtWm3atElxcXGqqamRJPn9fvXq1UsHDx7U2rVrdccdd6hPnz7au3evFi9erPHjxysjI6NDLgAAEJ08BdCqVask/fePTb9t9erVmj17tmJjY7Vt2za9+OKLamhoUFpamvLy8vTYY49FrGEAQNfg+SW480lLS1NpaWm7GsLZTZ8+vUPrO9KXX34Zdu3mzZs9zd3c3Bx27bPPPutp7rq6Ok/1ALxhLzgAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGAixl1of52LLBgMyu/3W7cBAGinQCCg+Pj4c57nDggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJTwG0atUqZWRkKD4+XvHx8crKytJ7770XOt/Y2Kj8/Hz16dNHvXv3Vl5enmprayPeNAAg+nkKoAEDBmjZsmUqLy/X7t27ddttt2nq1Kn6/PPPJUmLFy/W5s2btX79epWWluro0aOaMWNGhzQOAIhyrp2uvPJK9/rrr7u6ujrXs2dPt379+tC5L7/80klyZWVlYc8XCAScJAaDwWBE+QgEAuf9fd/m94Cam5u1bt06NTQ0KCsrS+Xl5Tp9+rSys7NDNcOHD9fAgQNVVlZ2znmampoUDAZbDQBA1+c5gD777DP17t1bPp9P8+bN04YNG3TdddeppqZGsbGxSkhIaFWflJSkmpqac85XVFQkv98fGmlpaZ4vAgAQfTwH0LBhw1RRUaFdu3Zp/vz5mjVrlr744os2N1BYWKhAIBAahw8fbvNcAIDo0cPrD8TGxmro0KGSpDFjxugvf/mLXnrpJc2cOVOnTp1SXV1dq7ug2tpaJScnn3M+n88nn8/nvXMAQFRr998BtbS0qKmpSWPGjFHPnj1VXFwcOldZWalDhw4pKyurvU8DAOhiPN0BFRYWKjc3VwMHDlR9fb3Wrl2rkpISbd26VX6/X3PmzFFBQYESExMVHx+vhQsXKisrSzfccENH9Q8AiFKeAujYsWP62c9+purqavn9fmVkZGjr1q26/fbbJUkvvPCCunXrpry8PDU1NSknJ0evvPJKhzQOAIhuMc45Z93EtwWDQfn9fus2AADtFAgEFB8ff87z7AUHADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMNHpAqiTbcwAAGijC/0+73QBVF9fb90CACACLvT7vNPtBdfS0qKjR48qLi5OMTExoePBYFBpaWk6fPjwefcWinZcZ9dxKVyjxHV2NZG4Tue
c6uvrlZqaqm7dzn2f4/kL6Tpat27dNGDAgHOej4+P79KL/z9cZ9dxKVyjxHV2Ne29znA2le50L8EBAC4NBBAAwETUBJDP59OSJUvk8/msW+lQXGfXcSlco8R1djUX8zo73YcQAACXhqi5AwIAdC0EEADABAEEADBBAAEATERNAK1cuVJXX321LrvsMmVmZuqTTz6xbiminnzyScXExLQaw4cPt26rXXbs2KEpU6YoNTVVMTEx2rhxY6vzzjk98cQTSklJUa9evZSdna39+/fbNNsOF7rO2bNnn7G2kydPtmm2jYqKijR27FjFxcWpf//+mjZtmiorK1vVNDY2Kj8/X3369FHv3r2Vl5en2tpao47bJpzrnDBhwhnrOW/ePKOO22bVqlXKyMgI/bFpVlaW3nvvvdD5i7WWURFAb7/9tgoKCrRkyRJ9+umnGj16tHJycnTs2DHr1iJqxIgRqq6uDo2PPvrIuqV2aWho0OjRo7Vy5cqznl++fLlWrFihV199Vbt27dIVV1yhnJwcNTY2XuRO2+dC1ylJkydPbrW2b7311kXssP1KS0uVn5+vnTt36oMPPtDp06c1adIkNTQ0hGoWL16szZs3a/369SotLdXRo0c1Y8YMw669C+c6Jem+++5rtZ7Lly836rhtBgwYoGXLlqm8vFy7d+/WbbfdpqlTp+rzzz+XdBHX0kWBcePGufz8/NDj5uZml5qa6oqKigy7iqwlS5a40aNHW7fRYSS5DRs2hB63tLS45ORk98wzz4SO1dXVOZ/P59566y2DDiPju9fpnHOzZs1yU6dONemnoxw7dsxJcqWlpc65/65dz5493fr160M1X375pZPkysrKrNpst+9ep3PO3XLLLe7++++3a6qDXHnlle7111+/qGvZ6e+ATp06pfLycmVnZ4eOdevWTdnZ2SorKzPsLPL279+v1NRUDR48WPfcc48OHTpk3VKHqaqqUk1NTat19fv9yszM7HLrKkklJSXq37+/hg0bpvnz5+v48ePWLbVLIBCQJCUmJkqSysvLdfr06VbrOXz4cA0cODCq1/O71/k/b775pvr27auRI0eqsLBQJ0+etGgvIpqbm7Vu3To1NDQoKyvroq5lp9uM9Lu+/vprNTc3KykpqdXxpKQk/e1vfzPqKvIyMzO1Zs0aDRs2TNXV1Vq6dKluvvlm7du3T3FxcdbtRVxNTY0knXVd/3euq5g8ebJmzJih9PR0HTx4UI8++qhyc3NVVlam7t27W7fnWUtLixYtWqQbb7xRI0eOlPTf9YyNjVVCQkKr2mhez7NdpyTdfffdGjRokFJTU7V37149/PDDqqys1DvvvGPYrXefffaZsrKy1NjYqN69e2vDhg267rrrVFFRcdHWstMH0KUiNzc39O+MjAxlZmZq0KBB+sMf/qA5c+YYdob2uvPOO0P/HjVqlDIyMjRkyBCVlJRo4sSJhp21TX5+vvbt2xf171FeyLmuc+7cuaF/jxo1SikpKZo4caIOHjyoIUOGXOw222zYsGGqqKhQIBDQH//4R82aNUulpaUXtYdO/xJc37591b179zM+gVFbW6vk5GSjrjpeQkKCrr32Wh04cMC6lQ7xv7W71NZVkgYPHqy+fftG5douWLBA7777rj788MNWX5uSnJysU6dOqa6urlV9tK7nua7zbDIzMyUp6tYzNjZWQ4cO1ZgxY1RUVKTRo0frpZdeuqhr2ekDKDY2VmPGjFFxcXHoWEtLi4qLi5WVlWXYWcc6ceKEDh48qJSUFOtWOkR6erqSk5NbrWswGNSuXbu69LpK0pEjR3T8+PGoWlvnnBYsWKANGzZo+/btSk9Pb3V+zJgx6tmzZ6v1rKys1KFDh6JqPS90nWdTUVEhSVG1nmfT0tKipqami7uWEf1IQwdZt26d8/l8bs2aNe6LL75wc+fOdQkJCa6mpsa6tYj51a9+5UpKSlxVVZX785//7LKzs13fvn3dsWPHrFtrs/r6erdnzx6
3Z88eJ8k9//zzbs+ePe6rr75yzjm3bNkyl5CQ4DZt2uT27t3rpk6d6tLT090333xj3Lk357vO+vp698ADD7iysjJXVVXltm3b5n7wgx+4a665xjU2Nlq3Hrb58+c7v9/vSkpKXHV1dWicPHkyVDNv3jw3cOBAt337drd7926XlZXlsrKyDLv27kLXeeDAAffUU0+53bt3u6qqKrdp0yY3ePBgN378eOPOvXnkkUdcaWmpq6qqcnv37nWPPPKIi4mJce+//75z7uKtZVQEkHPOvfzyy27gwIEuNjbWjRs3zu3cudO6pYiaOXOmS0lJcbGxse6qq65yM2fOdAcOHLBuq10+/PBDJ+mMMWvWLOfcfz+K/fjjj7ukpCTn8/ncxIkTXWVlpW3TbXC+6zx58qSbNGmS69evn+vZs6cbNGiQu++++6Luf57Odn2S3OrVq0M133zzjfvlL3/prrzySnf55Ze76dOnu+rqarum2+BC13no0CE3fvx4l5iY6Hw+nxs6dKh78MEHXSAQsG3co3vvvdcNGjTIxcbGun79+rmJEyeGwse5i7eWfB0DAMBEp38PCADQNRFAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDx/wAzgwe7F079ZQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# visualize first dataset\n", "x = next(iter(dataloader))[0][:1]\n", "x_as_image = make_grid(x.float(), nrow=1)\n", "plt.figure()\n", "plt.imshow(x_as_image.permute(1, 2, 0).cpu().numpy())\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAABkQAAAMtCAYAAADOtR3+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAq9ElEQVR4nO3df2yV9dn48avABGQtGxILlSpuyzYbYxuhEIZbRPtI1JBhtojbMhA3l23VSaoukiWwZSz+yGKI46jbMsU5psRlMLMsNYwtMA2TCmFh63RTWazD8kO3FjpFpef5w9jnywSl/bY9nIvXKzl/nPuc+9zX/Qcf2rx7n7uiWCwWAwAAAAAAILERpR4AAAAAAABgqAkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJDeqFIP0F+9vb2xe/fuqKysjIqKilKPAwAAAAAAlFCxWIwDBw5ETU1NjBhx7OtAyi6I7N69O2pra0s9BgAAAAAAcALp6OiIKVOmHPP1sgsilZWVEfHWiVVVVZV4GgAAAAAAoJS6u7ujtra2rx8cS9kFkbe/JquqqkoQAQAAAAAAIiLe8zYbbqoOAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6wx5EOjo64sILL4y6uro477zz4pFHHhnuEQAAAAAAgJPMqGE/4KhRsXLlymhoaIjOzs6YNm1aXHbZZTFu3LjhHgUAAAAAADhJDHsQmTx5ckyePDkiIiZNmhQTJ06MV155RRABAAAAAACGTL+/Mmvz5s0xb968qKmpiYqKili/fv073lMoFGLq1KkxZsyYmDlzZmzduvWon7Vt27Y4fPhw1NbW9ntwAAAAAACA49XvINLT0xP19fVRKBSO+vratWujpaUlli9fHtu3b4/6+vqYO3du7N2794j3vfLKK7Fw4cL40Y9+9K7HO3ToUHR3dx/xAAAAAAAA6I+KYrFYHPDOFRWxbt26mD9/ft+2mTNnRmNjY6xatSoiInp7e6O2tjauv/76uOWWWyLircjxP//zP3HttdfGF7/4xXc9xre//e34zne+847tXV1dUVVVNdDRAQCABF685Q8D2m/KbZ8c5EkAAIBS6e7ujvHjx79nN+j3FSLv5vXXX49t27ZFU1PT/x1gxIhoamqKLVu2REREsViMq6++Oi666KL3jCEREUuXLo2urq6+R0dHx2CODAAAAAAAnAQGNYjs378/Dh8+HNXV1Udsr66ujs7OzoiIeOKJJ2Lt2rWxfv36aGhoiIaGhti5c+cxP3P06NFRVVV1xAMAAAAAAKA/Rg33AS+44IL
o7e0d7sMCAAAAAAAnsUG9QmTixIkxcuTI2LNnzxHb9+zZE5MmTRrMQwEAAAAAABy3QQ0ip5xySkybNi02btzYt623tzc2btwYs2bNGsxDAQAAAAAAHLd+f2XWwYMH49lnn+17vmvXrtixY0dMmDAhzjzzzGhpaYlFixbF9OnTY8aMGbFy5cro6emJxYsXD+rgAAAAAAAAx6vfQeSpp56KOXPm9D1vaWmJiIhFixbF6tWrY8GCBbFv375YtmxZdHZ2RkNDQ7S2tr7jRusAAAAAAADDpaJYLBZLPUR/dHd3x/jx46OrqyuqqqpKPQ4AAFBCL97yhwHtN+W2Tw7yJAAAQKkcbzcY1HuIAAAAAAAAnIgEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEivbIJIoVCIurq6aGxsLPUoAAAAAABAmSmbINLc3Bzt7e3R1tZW6lEAAAAAAIAyUzZBBAAAAAAAYKAEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACC9sgkihUIh6urqorGxsdSjAAAAAAAAZaZsgkhzc3O0t7dHW1tbqUcBAAAAAADKTNkEEQAAAAAAgIESRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAIL2yCSKFQiHq6uqisbGx1KMAAAAAAABlpmyCSHNzc7S3t0dbW1upRwEAAAAAAMpM2QQRAAAAAACAgRJEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAAD
SE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANIrmyBSKBSirq4uGhsbSz0KAAAAAABQZsomiDQ3N0d7e3u0tbWVehQAAAAAAKDMlE0QAQAAAAAAGChBBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvbIJIoVCIerq6qKxsbHUowAAAAAAAGWmbIJIc3NztLe3R1tbW6lHAQAAAAAAykzZBBEAAAAAAICBEkQAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPTKJogUCoWoq6uLxsbGUo8CAAAAAACUmbIJIs3NzdHe3h5tbW2lHgUAAAAAACgzZRNEAAAAAAAABkoQAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA
9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASK9sgkihUIi6urpobGws9SgAAAAAAECZKZsg0tzcHO3t7dHW1lbqUQAAAAAAgDJTNkEEAAAAAABgoAQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID
0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvbIJIoVCIerq6qKxsbHUowAAAAAAAGWmbIJIc3NztLe3R1tbW6lHAQAAAAAAykzZBBEAAAAAAICBEkQAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAAD
SE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgvZIEkSuuuCI++MEPxmc/+9lSHB4AAAAAADjJlCSI3HDDDfHTn/60FIcGAAAAAABOQiUJIhdeeGFUVlaW4tAAAAAAAMBJqN9BZPPmzTFv3ryoqamJioqKWL9+/TveUygUYurUqTFmzJiYOXNmbN26dTBmBQAAAAAAGJB+B5Genp6or6+PQqFw1NfXrl0bLS0tsXz58ti+fXvU19fH3LlzY+/evQMa8NChQ9Hd3X3EAwAAAAAAoD/6HUQuvfTSWLFiRVxxxRVHff3OO++Ma6+9NhYvXhx1dXVx7733xqmnnhr33XffgAa89dZbY/z48X2P2traAX0OAAAAAABw8hrUe4i8/vrrsW3btmhqavq/A4wYEU1NTbFly5YBfebSpUujq6ur79HR0TFY4wIAAAAAACeJUYP5Yfv374/Dhw9HdXX1Edurq6vj6aef7nve1NQUf/rTn6KnpyemTJkSjzzySMyaNeuonzl69OgYPXr0YI4JAAAAAACcZAY1iByv3/72t6U4LAAAAAAAcJIa1K/MmjhxYowcOTL27NlzxPY9e/bEpEmTBvNQAAAAAAAAx21Qg8gpp5wS06ZNi40bN/Zt6+3tjY0bNx7zK7EAAAAAAACGWr+/MuvgwYPx7LPP9j3ftWtX7NixIyZMmBBnnnlmtLS0xKJFi2L69OkxY8aMWLlyZfT09MTixYsHdXAAAAAAAIDj1e8g8tRTT8WcOXP6nre0tERExKJFi2L16tWxYMGC2LdvXyxbtiw6OzujoaEhWltb33GjdQAAAAAAgOFSUSwWi6Ueoj+6u7tj/Pjx0dXVFVVVVaUeBwAAKKEXb/nDgPabctsnB3kSAACgVI63GwzqPUQAAAAAAABORIIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApCeIAAAAAAAA6QkiAAAAAABAeoIIAAAAAACQniACAAAAAACkJ4gAAAAAAADpCSIAAAAAAEB6gggAAAAAAJCeIAIAAAAAAKQniAAAAAAAAOkJIgAAAAAAQHqCCAAAAAAAkJ4gAgAAAAAApFc2QaRQKERdXV00NjaWehQAAAAAAKDMlE0QaW5ujvb29mhrayv1KAAAAAAAQJkZVeoB+qtYLEZERHd3d4knAQAASu3AoZ4B7ef3CQAAyOPtn+/f7gfHUnZB5MCBAxERUVtbW+JJAACAsrWy1AMAAACD7cCBAzF+/Phjvl5RfK9kcoLp7e2N3bt3R2VlZVRUVJR6HCh
b3d3dUVtbGx0dHVFVVVXqcYCErDPAULPOAEPNOgMMJWsMDJ5isRgHDhyImpqaGDHi2HcKKbsrREaMGBFTpkwp9RiQRlVVlf90gSFlnQGGmnUGGGrWGWAoWWNgcLzblSFvK5ubqgMAAAAAAAyUIAIAAAAAAKQniMBJavTo0bF8+fIYPXp0qUcBkrLOAEPNOgMMNesMMJSsMTD8yu6m6gAAAAAAAP3lChEAAAAAACA9QQQAAAAAAEhPEAEAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBJI6cOBALFmyJM4666wYO3ZsfOITn4i2trZ33efQoUPxrW99K84666wYPXp0TJ06Ne67775hmhgoNwNZZ9asWRP19fVx6qmnxuTJk+Oaa66Jl19+eZgmBk5kmzdvjnnz5kVNTU1UVFTE+vXrj3i9WCzGsmXLYvLkyTF27NhoamqKv//97+/5uYVCIaZOnRpjxoyJmTNnxtatW4foDIAT3VCsM7feems0NjZGZWVlnH766TF//vx45plnhvAsgBPZUP0887bbbrstKioqYsmSJYM7OJxEBBFI6stf/nJs2LAhHnzwwdi5c2dccskl0dTUFP/85z+Puc+VV14ZGzdujJ/85CfxzDPPxEMPPRQf+9jHhnFqoJz0d5154oknYuHChfGlL30p/vKXv8QjjzwSW7dujWuvvXaYJwdORD09PVFfXx+FQuGor99xxx1x1113xb333htPPvlkjBs3LubOnRuvvfbaMT9z7dq10dLSEsuXL4/t27dHfX19zJ07N/bu3TtUpwGcwIZindm0aVM0NzfHH//4x9iwYUO88cYbcckll0RPT89QnQZwAhuKdeZtbW1t8cMf/jDOO++8wR4bTioVxWKxWOohgMH16quvRmVlZfzqV7+Kyy+/vG/7tGnT4tJLL40VK1a8Y5/W1ta46qqr4vnnn48JEyYM57hAGRrIOvP9738/7rnnnnjuuef6tv3gBz+I22+/PV588cVhmRsoDxUVFbFu3bqYP39+RLz115Q1NTVx4403xk033RQREV1dXVFdXR2rV6+Oq6666qifM3PmzGhsbIxVq1ZFRERvb2/U1tbG9ddfH7fccsuwnAtwYhqsdea/7du3L04//fTYtGlTfOpTnxqq8YEyMJjrzMGDB+P888+Pu+++O1asWBENDQ2xcuXKYTgLyMcVIpDQm2++GYcPH44xY8YcsX3s2LHx+OOPH3WfRx99NKZPnx533HFHnHHGGfHRj340brrppnj11VeHY2SgzAxknZk1a1Z0dHTEb37zmygWi7Fnz574xS9+EZdddtlwjAyUsV27dkVnZ2c0NTX1bRs/fnzMnDkztmzZctR9Xn/99di2bdsR+4wYMSKampqOuQ9w8hrIOnM0XV1dERH+yAx4h/+fdaa5uTkuv/zyI/YFBmZUqQcABl9lZWXMmjUrvvvd78Y555wT1dXV8dBDD8WWLVviIx/5yFH3ef755+Pxxx+PMWPGxLp162L//v3x9a9/PV5++eW4//77h/kMgBPdQNaZ2bNnx5o1a2LBggXx2muvxZtvvhnz5s075uXkAG/r7OyMiIjq6uojtldXV/e99t/2798fhw8fPuo+Tz/99NAMCpStgawz/623tzeWLFkSs2fPjnPPPXfQZwTK20DXmYcffji2b9/+nvdrBI6PK0QgqQcffDCKxWKcccYZMXr06Ljrrrvic5/7XIwYcfR/9r29vVFRURFr1qyJGTNmxGWXXRZ33nlnPPDAA64SAY6qv+tMe3t73HDDDbFs2bLYtm1btLa2xj/+8Y/46le/OsyTAwAMvubm5vjzn/8cDz/8cKlHAZLo6OiIG264IdasWfOOq/OBgRFEIKkPf/jDsWnTpjh48GB0dHTE1q1b44033ogPfehDR33/5MmT44wzzojx48f3bTvnnHOiWCz6bn/gqPq7ztx6660xe/bsuPnmm+O8886LuXPnxt133x333XdfvPTSS8M8PVB
OJk2aFBERe/bsOWL7nj17+l77bxMnToyRI0f2ax/g5DWQdeb/dd1118Wvf/3r+P3vfx9TpkwZkhmB8jaQdWbbtm2xd+/eOP/882PUqFExatSo2LRpU9x1110xatSoOHz48JDPDdkIIpDcuHHjYvLkyfGvf/0rHnvssfj0pz991PfNnj07du/eHQcPHuzb9re//S1GjBjhB3rgXR3vOvOf//znHVePjBw5MiLeusEgwLGcffbZMWnSpNi4cWPftu7u7njyySdj1qxZR93nlFNOiWnTph2xT29vb2zcuPGY+wAnr4GsMxFv/Qxz3XXXxbp16+J3v/tdnH322cMxLlCGBrLOXHzxxbFz587YsWNH32P69OnxhS98IXbs2NH3+xRw/AQRSOqxxx6L1tbW2LVrV2zYsCHmzJkTH//4x2Px4sUREbF06dJYuHBh3/s///nPx2mnnRaLFy+O9vb22Lx5c9x8881xzTXXxNixY0t1GsAJrL/rzLx58+KXv/xl3HPPPfH888/HE088Ed/4xjdixowZUVNTU6rTAE4QBw8e7PtFP+KtG4/u2LEjXnjhhaioqIglS5bEihUr4tFHH42dO3fGwoULo6amJubPn9/3GRdffHGsWrWq73lLS0v8+Mc/jgceeCD++te/xte+9rXo6enpW6eAk8tQrDPNzc3xs5/9LH7+859HZWVldHZ2Rmdnp68dhpPUYK8zlZWVce655x7xGDduXJx22mnuVQQD5KbqkFRXV1csXbo0XnzxxZgwYUJ85jOfie9973vxvve9LyIiXnrppXjhhRf63v/+978/NmzYENdff31Mnz49TjvttLjyyitjxYoVpToF4ATX33Xm6quvjgMHDsSqVavixhtvjA984ANx0UUXxe23316qUwBOIE899VTMmTOn73lLS0tERCxatChWr14d3/zmN6Onpye+8pWvxL///e+44IILorW19Yjv037uuedi//79fc8XLFgQ+/bti2XLlkVnZ2c0NDREa2vrO25mCpwchmKdueeeeyIi4sILLzziWPfff39cffXVQ3cywAlpKNYZYHBVFH1HBQAAAAAAkJyvzAIAAAAAANITRAAAAAAAgPQEEQAAAAAAID1BBAAAAAAASE8QAQAAAAAA0hNEAAAAAACA9AQRAAAAAAAgPUEEAAAAAABITxABAAAAAADSE0QAAAAAAID0BBEAAAAAACC9/wUt71o0Sbl9QgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# check parameter distribution\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "def plot_param_dist(model):\n", " plot_dist = {}\n", " for name, param in model.named_parameters():\n", " if param.requires_grad:\n", " plot_dist[name] = param.detach().cpu().numpy().flatten()\n", " # sample 1000\n", " # replace nan with 1000\n", " plot_dist[name] = np.where(np.isnan(plot_dist[name]), 10, plot_dist[name])\n", " plot_dist[name] = np.random.choice(plot_dist[name], 1000)\n", " \n", " plt.figure(figsize=(20, 10))\n", "\n", " for idx, (name, dist) in enumerate(plot_dist.items()):\n", " \n", " plt.hist(dist, bins=100, density=True)\n", " \n", " # log y\n", " plt.yscale('log')\n", "\n", " #plt.legend()\n", " plt.show()\n", "\n", "\n", "plot_param_dist(d3pm.x0_model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiYAAACxCAYAAADwMnaUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkhklEQVR4nO3dfWxV5R0H8G8LtKDQy4uhtSuFbmNDJ1YHWjvMcNrJiFEYZFPDRlESoyuO0mRTtqHGzVVdNl58gbktwDIZjkRgkKhhRUpMylsRJuIqy5igeMucay+ilI4++2Pzjvvr8f7O755T7mn5fpKb9Nzz9pznvPTJeX739+Q45xyIiIiIIiA32wUgIiIi+hgbJkRERBQZbJgQERFRZLBhQkRERJHBhgkRERFFBhsmREREFBlsmBAREVFksGFCREREkcGGCREREUUGGyZEREQUGT3WMHnqqacwZswYDBw4EBUVFdi1a1dP7YqIiIj6iJyeGCvnueeew+zZs7FixQpUVFRgyZIlWLduHVpaWjBy5Mi063Z1deHYsWMYMmQIcnJywi4aERER9QDnHE6cOIHi4mLk5mb+3qNHGiYVFRW46qqr8OSTTwL4b2Nj1KhRuPfee3H//fenXfftt9/GqFGjwi4SERERnQNHjx5FSUlJxuuH3pVz+vRpNDc3o6qq6v87yc1FVVUVmpqaui3f0dGBRCKR/HCwYyIiot5ryJAhgdYPvWHy3nvv4cyZMygsLEz5vrCwEPF4vNvy9fX1iMViyU9paWnYRSIiIqJzJGgYRtZ/lbNw4UK0t7cnP0ePHs12kYiIiChL+oe9wYsuugj9+vVDa2tryvetra0oKirqtnx+fj7y8/PDLgYRERH1QqG/McnLy8OECRPQ0NCQ/K6rqwsNDQ2orKwMe3dERETUh4T+xgQA6urqUF1djYkTJ+Lqq6/GkiVLcPLkSdxxxx09sTsiIiLqI3qkYXLrrbfiH//4Bx544AHE43FcccUVePHFF7sFxGYqFoulTI8ZMyZluq2tLe20XN7L3//+95TpoUOHpl1e7kNb3lpmKezy+NmnRtuHnN/Y2Jh2+dGjR6dMa8dkPWY/x+un3oIIet3I61TS6shrGeu1L+dfccUVKdOrV69Ou355eXnKdNDrMAzaedGm9+3bl3a+9VqWMnl+aNeOVobrrrsu7fr79+/vXtCzyPvZWj5NJs+8ntiGZX3rvZYJ6zFp18Fbb70VvFCKHmmYAMC8efMwb968nto8ERER9UFZ/1UOERER0cfYMCEiIqLI6LGunJ40Z86clGnZByb7dyWtf9iL1hdonQ47psS6Pz/kNoLGN1hp/eja/uT6Wr+/F61/dvr06SnT27ZtU7eZbvtyexs2bEi7vGStA6D7eZSxBNryWoyKRt7Psoxye2Fcd1o9yjgZLYZEK4Oclscsz7MWB5dJbELQY5bXhbzWtRgTuX1ZHnlMPfG8sT77ZZnledfqVIvTsZ4Trc68aM9ROV/uU57ncxFjwjcmREREFBlsmBAREVFksGFCREREkdErY0y0/l2t3172K3rFBch+N61vUOv7C9qXKGnHILcnj8erb1UrszV3gpyWZdLymGjry+tAq0NZfq9+fFmPchl5LVn7+rXzoPWzy/1L2nXgtf6SJUvSlkETNPeDLLM1ZkWLffCKmdHi0mS9yTJo16IW46U9w7T9yxgV7br1KoPcpiyTtg9rPJVWHu260eK5tOeF1z7ktbFq1aq0+9TOq/UZqj2ztPX93CvyGLU4GTlf7lOLJQoD35gQERFRZLBhQkRERJHBhgkRERFFRo5zzmW7EGdLJBLdxsKRtLE1tH45P3EB1rEqgo55YM0zEnTsjUyO2Vomrf+zvb097fas51krTyaxENZxZIKOgWLVE/vX1tHiK6x90pMnT06Ztsa4+MnVYt2G9Txr2wt63q3PND9lsj6ztBw7Wn4L65hI1ntP8tq+fObI/zXW82ytMxnDEjTXU9jPE0DPoeM
nxqS9vR0FBQUZl4FvTIiIiCgy2DAhIiKiyGDDhIiIiCKDDRMiIiKKjF4Z/Dp69Oi0860BTF7zMwkus5TBGrgVNDivJwb5ClonWhCVdp6tMgkA1urVGkyn7U8KGkTdE8Gv2vKSFhQp6yzovXcuaMccdtC0dft+BA28lLTzrAU5B73O/NShFritCZpMMOwfGGj/J/ywlsnPIH4MfiUiIqI+gw0TIiIiigw2TIiIiCgyeuUgfkEH0Mok3kJLmGTtE9b69eTAS9bkQpnEg1gHqLKy9mVqsQbWAa606wTQ612yDsKnzdeSyFmPORPW5GCyzmTiLY024KRkvTe9yt/TMWDatWaNo9GSF2aS0E0bQE4egzYYnMYaw5bJ/avRyqwdozXGI2hskfX/ip9ltGvtoYceSjt9LvCNCREREUUGGyZEREQUGWyYEBERUWT0yhgTrd9P60OTfZlefeLWmJJMcqWk2582MJp2jJnkBLDGL1jjH4LmPQmaE8RPPMa2bdvSLmOtV61PO+ycIXJ9beA1r21a+8XlMdbW1qZML1iwoHtB05DnQNLuRT/5MayxBFq8hSyzNZ+E9RxosUpe14k8hqDxU3K+lt9CLi+vzSVLlqRd3hpL5HUOtPtBqyM5CJ8sc9DYQWvMiZ94Kuv9Io9J1tnSpUu77SNsfGNCREREkcGGCREREUUGGyZEREQUGb0yxiRoX6PsV/Tqj7X27cnfest+di1nQCZ9h5b1pUzGbLDGtQSNMdHqwBpbJNeX/b1erH35WmyQFp9hFcb4IUHH75Dryz5qjTX3i7a+n9wOWhyLJLdpjYPR5gcdcyWT3C1Bc+gELbP1XrHmVfKqA7lP6zHJ57q1DuR1EzR2MZOcXNo2tPipc4FvTIiIiCgy2DAhIiKiyDA3TLZv346bb74ZxcXFyMnJ6dYt4pzDAw88gIsvvhiDBg1CVVUVDh06FFZ5iYiIqA8zx5icPHkS5eXluPPOOzFjxoxu8x9//HEsW7YMq1evRllZGRYtWoQpU6bg4MGDGDhwYCiFlqz5N2Qfmp/ffmu5T6z96tr2wx6Xxs/2tWWs/alB4yes/a+y/JI1JsVrm/KYtPE9tDLIHAEyT4I1H0YmcQHaMvIYrf3kWn4La04Q7bqU25N16rVNydovr+UhCjoGU9AcQ372Yc1fE5R2DNp1Zr03vLah3Y9yeeu1r+VNseY10Xgtrz2j5D5lGTOp56DMDZOpU6di6tSpnvOcc1iyZAl+9KMfYdq0aQCA3/72tygsLMSGDRtw2223BSstERER9WmhxpgcPnwY8XgcVVVVye9isRgqKirQ1NTkuU5HRwcSiUTKh4iIiM5PoTZM4vE4AKCwsDDl+8LCwuQ8qb6+HrFYLPkZNWpUmEUiIiKiXiTreUwWLlyIurq65HQikVAbJ9p4IFo/YCaxE1q/e9Df+Gu/FbfGMmjHOGfOnG7r+Mnvko52XuT2tdgDaz+/dsxaXABgH9tGy8EhyeXlebfGGmnHqO3fixZTYs3xo7HmBLHWsR/WZ4Q1BkSL/dFimYLmOfIqk1bvYZ9nub4WyyC3L3OIaM95P2XQ6lEbO0c+4+R8OS1jSuR5l/mwrM9gP+MDWfcRdmyRH6G+MSkqKgIAtLa2pnzf2tqanCfl5+ejoKAg5UNERETnp1AbJmVlZSgqKkJDQ0Pyu0QigZ07d6KysjLMXREREVEfZO7K+eCDD/DXv/41OX348GHs27cPw4cPR2lpKWpra/GTn/wEY8eOTf5cuLi4uNvrJCIiIiIpxznnLCts27YNX/nKV7p9X11djVWrVsE5hwcffBDPPPMM2tracO211+Lpp5/G5z73OV/bTyQSiMViaZcpLy9PO1+L99DGN/FaRusD1n6Tbx3XIWgsg9Yv6LW+dcwSrV61etZiTOR1YK1j2Z/rJ6+KNUZDy19h7aeX/ehye9Z+eW2+13daLgfreCP79+/vts+zjR49Ou36Vpn0iVvjWHo6/kkrj/b88RLGuCv
ptqfdz/K5HTRuRnseeLHGgFhjzqx1rD0jrXFyXnFDQf83Sdp5BoD29vZAYRnmNybXXXcd0rVlcnJy8PDDD+Phhx/OuFBERER0fuJYOURERBQZbJgQERFRZGQ9j0kmtLEwtH4/LQeJ13dB+xq12ARrnhKZh8RrPJB02/PavhZTosUryOXlMcn1tb7KoPETWp4Gr355a1yNnxiOs2ljY8gYEut4QVr5vVjzlmi5Vqx91kHzZWSSh0E7b0H7/q3PHGtMidyfjE3yGrtLiyWS503uU+5De+ZotPPq5zl9Nj/Xodym9dq33u9S2HEzfnJyaceg5WaRdeInxiQovjEhIiKiyGDDhIiIiCKDDRMiIiKKDHMek57mJ4/J5MmTU6a1/tqgeRG8tqH17Wl9zHL9oONOBI1F8CpT0HgGrb9048aN3cpwNpnfImjOkEzy1wQdH0T2y8txKrQ6sl7bYVz71tgi7brQ8phoeYkkuT8Zt6PFFvmh5QWxxpRpeZA01txMfuLmZJm05bX73Xqegz5TJa/xvyQtHlHSxvsKOzeM9f+IFhPjZx/WODrtuQ0Ez2PCNyZEREQUGWyYEBERUWSwYUJERESR0SvzmFj7PjPJYxI03sLaLy+F/Xv5TMZH0MaR0Oo5kzFLzqbFX2jT8jzL/uFMxtaw9hlrY9v4iXuxkHXuJ79FJnlAzib79uXyWuyBNWZF8jNOjCSvZUnru7fm1LHmdtFiCSQ/97e2jPZcDCNWLx3tOtByPVnjR7yWkfeflqvFep6sOYCsz1Q/512brz03zwW+MSEiIqLIYMOEiIiIIoMNEyIiIoqMPpHHxJozwM+4MVouB2sMirWfz9onbc1v4acM1t+7W2ljLsjzLM+B7M+15nrwquNMxpqx7FMTNJdDJtehNcbEei03NjamXV7mt9DuPY2f8gbdh7bPnjhP6eZnksfEmjNHi4+w3s/atW6NgfFDW8caO2SNh9KeJ1oOEW1/XjEuWgyYFnsn47FWr17dbR8S85gQERFRn8GGCREREUUGGyZEREQUGWyYEBERUWT0iQRrkhbsqi0P2INdwx4AT7JuT/KzfNiBljLoSgZZacFyUtA6kgPoeSUbswbcWROaWQci1FivUy/aebYm3rImidMC/iRroKjX8QUNZtcS5WnHYE1iJfevDS7n57xr16KkBQxr97P1PMnty+eJvLf8BCAHfY5atydZk1Rq/ATny4BYazLQbOAbEyIiIooMNkyIiIgoMtgwISIiosjolTEmWgI1SYsb8BNvofXHav10WkxK2P16YSQfsiZkCjshmzWmRBuYTfZJ90SCNTnoV9CkUNY6yyTpnZZIy5poS5ZBiz2w1okWF+BnYLWg50Hbh5b4SsaIWGNerDEpXtuwxr1pz0QtkZ58ZmrPcbl/GSMmZZJkzjpfqyPt/4QWUybrRLI+EwE9Nkd7bgYdWDQTfGNCREREkcGGCREREUUGGyZEREQUGb0yxkT2A1r73bT4ED/LaPEYWg4PrYzafOvv5eX+/fS/hv37dmu8hjWfhYzvsPbfem1Tm6/F1Vj7pK05RbQ4AHkdyDryWkfS+v6D0gak0+5FWR5573nFmGi0e0Hrd5f7lMck52sxZ1reFMmrfFo9a/UUNF+Ndu1q86287ndr3Jyk3V9anWjxW3J9a3m8yJgS7bzL+Kds4BsTIiIiigxTw6S+vh5XXXUVhgwZgpEjR2L69OloaWlJWebUqVOoqanBiBEjMHjwYMycOROtra2hFpqIiIj6JlPDpLGxETU1NdixYwe2bNmCzs5O3HjjjTh58mRymQULFmDTpk1Yt24dGhsbcezYMcyYMSP0ghMREVHfY4oxefHFF1OmV61ahZEjR6K5uRlf/vKX0d7ejt/85jdYs2YNrr/+egDAypUrcckll2DHjh245pprQim01gem9e9qeRkAPQ+BZO2vteZBkGWWfZtyvlYHfspkzWOQyT7TseaCkLQ68WLtB9fiH6x9xtr+tetEkrF
FXte63IfWbx52LJIWM2bt19fGJ/KzD4113Bc/8U3pti+vdWt+DK8yafvUjlHW6/79+9MuL4/Ba6yqs1njbPzEHmr3lzUGRYuD0Z4Hsk60543kJw4naPxiNsbOCRRj0t7eDgAYPnw4AKC5uRmdnZ2oqqpKLjNu3DiUlpaiqakpyK6IiIjoPJDxr3K6urpQW1uLSZMm4bLLLgMAxONx5OXldWuBFRYWIh6Pe26no6MDHR0dyelEIpFpkYiIiKiXy/iNSU1NDQ4cOIC1a9cGKkB9fT1isVjyM2rUqEDbIyIiot4rozcm8+bNw+bNm7F9+3aUlJQkvy8qKsLp06fR1taW8taktbUVRUVFnttauHAh6urqktOJREJtnMhcDFp/sdZvmMlv/q19lVq8hBa7EHYMi59YA0n7zb0W7xCUVofaeda2B+j13tOxCdY4He2YMxk/xJpzRyujNlaOFqMi61yeIy1vide9Yq1Xa34b7f7TYsQkGY+h5cPwc11q+9Rij6x5RrSYEuszNZNj1u5f6/2tra8917V7TYtJ0ZbPRNh5ijJhemPinMO8efOwfv16bN26FWVlZSnzJ0yYgAEDBqChoSH5XUtLC44cOYLKykrPbebn56OgoCDlQ0REROcn0xuTmpoarFmzBhs3bsSQIUOScSOxWAyDBg1CLBbD3LlzUVdXh+HDh6OgoAD33nsvKisrQ/tFDhEREfVdpobJ8uXLAXR/pbdy5crk69TFixcjNzcXM2fOREdHB6ZMmYKnn346lMISERFR35bjnHPZLsTZEokEYrFY2mUmT56cdr61P9gP2Q8u+w5l3IvWby77xbXxDLSYFet4QT3BGtfS2NiYdnvl5eUp09Y8DJnkFLGOz6PFFgSNVbCO5aHFW3itb43JsuYZWb16dfeCnmX06NFp59fW1qZMa7EKkp94qqBjHmn7tMYOWc9JJrRj1JaX19rSpUvTri+f29Zxo4LeG5msI49RG5/Legza8tq9pcWoAPZrUcuZoz23gf+mEgkSlsGxcoiIiCgy2DAhIiKiyGDDhIiIiCKjV8aYaH3SWj+i7LN+6KGH1G0EzXtgzS+RSf9pOn5yjFh/v671l2rLa/ktpk2blnZ9ydpnHYag8RjWcSqs141WHi9aPIW1DNbzbB3rKox4DO3asF5LQc+LNi6M5GesHCt5DFoODS2WSMaMaax1lMkYL0Gfs9b1g+7Pem/6WcYa96LdzwBjTIiIiKgPYcOEiIiIIoMNEyIiIoqMXhljYs1vkUn8hpaPImgeEbm+/D160DFYrH3mfrahjV1hzc2wf//+tMtrsUQa6zkC7LFEQVn7dyVr7gmv47PePzKXghbfEPQ8B43D8UNbJ2i8kjV+S8tXoV3bXrlb5HnSxnHRxvOR+9RiD7TzHDQux+uYNVrcilZnQePygsZHhZHLRdLKoN3PAGNMiIiIqA9hw4SIiIgigw0TIiIiigzTIH5RZe2b9NMHrfUdyv5XmXtB63e3jsWhjdUj96fxOj5rPIP19/BWQXPHyDrR8jB4rRO0TzhoXhJr7pZMWK9N67UWdP8ard/fqw6t1651rBvJet79jIFiFTR/hbY9P/kt0gk6Vo7kdR3JfVjHqtGEXceaMOLmrGX2E2MSFN+YEBERUWSwYUJERESRwYYJERERRUavjDHR4jvmzJmTMi37a2UfWia/+V+yZEnaMkpamYPGHli3nwlrn6+ctsYmhDHuS7r9+6kT6/g/Whm1/l5rngRrnfjpd5e0Ywwag2KNTbJey173t3XcFXl/aXEsQeMl5DNs1apVSMfPvWaNtdOek/IZo7HGd8hjkHUi5/t5HmjjpMk6keddlln7PxA0P441T4mfnFzac1D7X3Iu8I0JERERRQYbJkRERBQZbJgQERFRZLBhQkRERJHRKwfxmzZtWsq0FmBoDUD0WiaToKOzyeReWsImLZDLmgwtkzrQtqmVSQsIbGxsTLu9xYsXp0zLAMCwk1x5LaPRAj+DDgq
YSRKpdOv7qbOgwahyWku8NX/+/JRpeW9o+9cC072OWX6nBXYGHfgsaCI9beBErTxetGeENeDXOlij3L8MbtUCU7Vz5nXdWIOeteemlnRS27/1eWNNlua1D2ugttym9twGOIgfERER9SFsmBAREVFksGFCREREkdErE6xZE/9o/b1+BpvTkv8ETS4m+x7lMcn+Xpk0zjromFffpuzjletoA4tp9WrtT5UxJVryMS0pnp9EYFq/u6SdV61OrLFEWv+w1ifeE/FVWiyCFmMil9cSaWkD3snz7lV+GfMly6DN1+pdez4EHSxOri/L63WvytgbOW2NNZD1rMWYyP1Z46Os17rX9rVltFhA7X603q8aa2I+r2s9aKLKMJJzWvGNCREREUUGGyZEREQUGWyYEBERUWT0yjwm5eXlaedb+2/9xFtYc2hY+/q1eAgtpiTowGleZdL6L7XBobQBsZYuXZq2jPI8a33M1vMuywPoA1Zpxxw0xkOLt5DXQdBzBOi5GOQ6Yee3kHlMtHguSTtmP33k8pitcWphD+JnzXviJ8+JNfeRxhpLJPOYaDEk2nnXyutnUL+gzxTtWrPmq9Hi6LTyeNWJ9dqT5PyNGzemXR5gHhMiIiLqQ0wNk+XLl+Pyyy9HQUEBCgoKUFlZiRdeeCE5/9SpU6ipqcGIESMwePBgzJw5E62traEXmoiIiPomU8OkpKQEjz76KJqbm7Fnzx5cf/31mDZtGl5//XUAwIIFC7Bp0yasW7cOjY2NOHbsGGbMmNEjBSciIqI+yAU0bNgw9+tf/9q1tbW5AQMGuHXr1iXnvfHGGw6Aa2pq8r299vZ2B4Affvjhhx9++OmFn/b29kDtioxjTM6cOYO1a9fi5MmTqKysRHNzMzo7O1FVVZVcZty4cSgtLUVTU9MnbqejowOJRCLlQ0REROcnc8Pktddew+DBg5Gfn4+7774b69evx6WXXop4PI68vLxuEbyFhYWIx+OfuL36+nrEYrHkZ9SoUeaDICIior7B3DD5/Oc/j3379mHnzp245557UF1djYMHD2ZcgIULF6K9vT35OXr0aMbbIiIiot7NPFZOXl4ePvvZzwIAJkyYgN27d2Pp0qW49dZbcfr0abS1taW8NWltbUVRUdEnbi8/Px/5+fn2khMREVGfEziPSVdXFzo6OjBhwgQMGDAADQ0NyXktLS04cuQIKisrg+6GiIiIzgOmNyYLFy7E1KlTUVpaihMnTmDNmjXYtm0bXnrpJcRiMcydOxd1dXUYPnw4CgoKcO+996KyshLXXHNNT5WfiIiI+hBTw+T48eOYPXs23n33XcRiMVx++eV46aWX8NWvfhUAsHjxYuTm5mLmzJno6OjAlClT8PTTT5sK5KKVIZ+IiIgMgv4fj9xYOW+//TZ/mUNERNRLHT16FCUlJRmvH7mGSVdXF44dOwbnHEpLS3H06NFAgwGd7xKJBEaNGsV6DIB1GBzrMBysx+BYh8F9Uh0653DixAkUFxcjNzfzEFbzr3J6Wm5uLkpKSpKJ1j4el4eCYT0GxzoMjnUYDtZjcKzD4LzqMBaLBd4uRxcmIiKiyGDDhIiIiCIjsg2T/Px8PPjgg0y+FhDrMTjWYXCsw3CwHoNjHQbX03UYueBXIiIiOn9F9o0JERERnX/YMCEiIqLIYMOEiIiIIoMNEyIiIoqMyDZMnnrqKYwZMwYDBw5ERUUFdu3ale0iRVZ9fT2uuuoqDBkyBCNHjsT06dPR0tKSssypU6dQU1ODESNGYPDgwZg5cyZaW1uzVOLoe/TRR5GTk4Pa2trkd6xDf9555x1861vfwogRIzBo0CCMHz8ee/bsSc53zuGBBx7AxRdfjEGDBqGqqgqHDh3KYomj5cyZM1i0aBHKysowaNAgfOYzn8GPf/zjlPFHWIeptm/fjptvvhnFxcXIycnBhg0bUub7qa/3338fs2bNQkFBAYYOHYq5c+figw8+OIdHkX3p6rGzsxP33Xcfxo8fjws
vvBDFxcWYPXs2jh07lrKNMOoxkg2T5557DnV1dXjwwQexd+9elJeXY8qUKTh+/Hi2ixZJjY2NqKmpwY4dO7BlyxZ0dnbixhtvxMmTJ5PLLFiwAJs2bcK6devQ2NiIY8eOYcaMGVksdXTt3r0bv/zlL3H55ZenfM861P3rX//CpEmTMGDAALzwwgs4ePAgfv7zn2PYsGHJZR5//HEsW7YMK1aswM6dO3HhhRdiypQpOHXqVBZLHh2PPfYYli9fjieffBJvvPEGHnvsMTz++ON44oknksuwDlOdPHkS5eXleOqppzzn+6mvWbNm4fXXX8eWLVuwefNmbN++HXfddde5OoRISFePH374Ifbu3YtFixZh7969eP7559HS0oJbbrklZblQ6tFF0NVXX+1qamqS02fOnHHFxcWuvr4+i6XqPY4fP+4AuMbGRuecc21tbW7AgAFu3bp1yWXeeOMNB8A1NTVlq5iRdOLECTd27Fi3ZcsWN3nyZDd//nznHOvQr/vuu89de+21nzi/q6vLFRUVuZ/97GfJ79ra2lx+fr77/e9/fy6KGHk33XSTu/POO1O+mzFjhps1a5ZzjnWoAeDWr1+fnPZTXwcPHnQA3O7du5PLvPDCCy4nJ8e9884756zsUSLr0cuuXbscAPfWW28558Krx8i9MTl9+jSam5tRVVWV/C43NxdVVVVoamrKYsl6j/b2dgDA8OHDAQDNzc3o7OxMqdNx48ahtLSUdSrU1NTgpptuSqkrgHXo1x//+EdMnDgR3/jGNzBy5EhceeWV+NWvfpWcf/jwYcTj8ZR6jMViqKioYD3+z5e+9CU0NDTgzTffBADs378fr7zyCqZOnQqAdWjlp76ampowdOhQTJw4MblMVVUVcnNzsXPnznNe5t6ivb0dOTk5GDp0KIDw6jFyg/i99957OHPmDAoLC1O+LywsxF/+8pcslar36OrqQm1tLSZNmoTLLrsMABCPx5GXl5e8eD5WWFiIeDyehVJG09q1a7F3717s3r272zzWoT9/+9vfsHz5ctTV1eEHP/gBdu/eje9+97vIy8tDdXV1sq687m/W43/df//9SCQSGDduHPr164czZ87gkUcewaxZswCAdWjkp77i8ThGjhyZMr9///4YPnw46/QTnDp1Cvfddx9uv/325EB+YdVj5BomFExNTQ0OHDiAV155JdtF6VWOHj2K+fPnY8uWLRg4cGC2i9NrdXV1YeLEifjpT38KALjyyitx4MABrFixAtXV1VkuXe/whz/8Ac8++yzWrFmDL3zhC9i3bx9qa2tRXFzMOqRI6OzsxDe/+U0457B8+fLQtx+5rpyLLroI/fr16/Zrh9bWVhQVFWWpVL3DvHnzsHnzZrz88ssoKSlJfl9UVITTp0+jra0tZXnW6f81Nzfj+PHj+OIXv4j+/fujf//+aGxsxLJly9C/f38UFhayDn24+OKLcemll6Z8d8kll+DIkSMAkKwr3t+f7Hvf+x7uv/9+3HbbbRg/fjy+/e1vY8GCBaivrwfAOrTyU19FRUXdflzx73//G++//z7rVPi4UfLWW29hy5YtybclQHj1GLmGSV5eHiZMmICGhobkd11dXWhoaEBlZWUWSxZdzjnMmzcP69evx9atW1FWVpYyf8KECRgwYEBKnba0tODIkSOs0/+54YYb8Nprr2Hfvn3Jz8SJEzFr1qzk36xD3aRJk7r9VP3NN9/E6NGjAQBlZWUoKipKqcdEIoGdO3eyHv/nww8/RG5u6qO5X79+6OrqAsA6tPJTX5WVlWhra0Nzc3Nyma1bt6KrqwsVFRXnvMxR9XGj5NChQ/jTn/6EESNGpMwPrR4zCNbtcWvXrnX5+flu1apV7uDBg+6uu+5yQ4cOdfF4PNtFi6R77rnHxWIxt23bNvfuu+8mPx9++GFymbvvvtuVlpa6rVu3uj179rjKykpXWVmZxVJH39m/ynGOdejHrl27XP/+/d0jjzziDh065J599ll3wQU
XuN/97nfJZR599FE3dOhQt3HjRvfnP//ZTZs2zZWVlbmPPvooiyWPjurqavepT33Kbd682R0+fNg9//zz7qKLLnLf//73k8uwDlOdOHHCvfrqq+7VV191ANwvfvEL9+qrryZ/LeKnvr72ta+5K6+80u3cudO98sorbuzYse7222/P1iFlRbp6PH36tLvllltcSUmJ27dvX8r/mo6OjuQ2wqjHSDZMnHPuiSeecKWlpS4vL89dffXVbseOHdkuUmQB8PysXLkyucxHH33kvvOd77hhw4a5Cy64wH3961937777bvYK3QvIhgnr0J9Nmza5yy67zOXn57tx48a5Z555JmV+V1eXW7RokSssLHT5+fnuhhtucC0tLVkqbfQkEgk3f/58V1pa6gYOHOg+/elPux/+8IcpD3/WYaqXX37Z8xlYXV3tnPNXX//85z/d7bff7gYPHuwKCgrcHXfc4U6cOJGFo8medPV4+PDhT/xf8/LLLye3EUY95jh3VjpBIiIioiyKXIwJERERnb/YMCEiIqLIYMOEiIiIIoMNEyIiIooMNkyIiIgoMtgwISIioshgw4SIiIgigw0TIiIiigw2TIiIiCgy2DAhIiKiyGDDhIiIiCKDDRMiIiKKjP8AhoB1uxYGu3IAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 
0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116, 0.0126],\n", " [0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126,\n", " 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.0126, 0.8116]],\n", " device='cuda:0')\n" ] } ], "source": [ "print(d3pm.q_mats[200])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1lklEQVR4nO3de3xU9Z3/8XcSkkmCTGKwSUgJmK2tJAhyq2TqpaghKU1dL+mFNsWsoq40tCbZBWR/ELmoESx3I5SKxD5KWnG3WAUKGYKAlHALRrlYtCs1tjiT3WIYAZkMyfz+6CNnGbkOJhm+w+v5eMzj4Zzzme/5fDIxvB/nzEki/H6/XwAAAAaJDHUDAAAAwSLAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACM0y3UDXSWtrY2HT58WD169FBERESo2wEAABfB7/fr008/VVpamiIjz32eJWwDzOHDh5Wenh7qNgAAwCX46KOP1Lt373PuD9sA06NHD0n/+ALY7fYOW9fn86mmpka5ubmKjo7usHUvZ1fazMwb3pg3vDGv+Twej9LT061/x88lbANM+2Uju93e4QEmPj5edrs9bL5ZLuRKm5l5wxvzhjfmDR8X+vgHH+IFAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAME63UDdgqhumrZe39fx/6rsj/eWZ/C47FgAAlzvOwAAAAOMEFWBaW1s1depUZWRkKC4uTl/5ylc0c+ZM+f1+q8bv96u8vFy9evVSXFyccnJy9P777wesc+TIERUWFsputysxMVFjx47VsWPHAmreeecd3XrrrYqNjVV6erpmz579BcYEAADhJKgAM2vWLC1evFjPPfec3n33Xc2aNUuzZ8/WokWLrJrZs2dr4cKFWrJkiXbs2KHu3bsrLy9PJ0+etGoKCwu1f/9+OZ1OrV69Wlu2bNEjjzxi7fd4PMrNzVXfvn1VX1+vZ599VtOmTdPSpUs7YGQAAGC6oD4Ds23bNt19993Kz//H5zGuvfZa/eY3v9HOnTsl/ePsy/z58zVlyhTdfffdkqRf/epXSklJ0auvvqrRo0fr3Xff1bp167
Rr1y4NGzZMkrRo0SJ9+9vf1s9//nOlpaVpxYoVamlp0YsvvqiYmBj1799fDQ0Nmjt3bkDQAQAAV6agAsw3vvENLV26VO+9956+9rWv6e2339bWrVs1d+5cSdKhQ4fkcrmUk5NjvSYhIUHDhw9XXV2dRo8erbq6OiUmJlrhRZJycnIUGRmpHTt26N5771VdXZ1uu+02xcTEWDV5eXmaNWuWPvnkE1199dVn9Ob1euX1eq3nHo9HkuTz+eTz+YIZ87za17JF+i9Q2bE6coZLPXYoe+hKzBvemDe8Ma/5LnaWoALM448/Lo/Ho379+ikqKkqtra166qmnVFhYKElyuVySpJSUlIDXpaSkWPtcLpeSk5MDm+jWTUlJSQE1GRkZZ6zRvu9sAaaiokLTp08/Y3tNTY3i4+ODGfOizBzW1uFrns/atWu79Hhn43Q6Q91Cl2Le8Ma84Y15zXXixImLqgsqwKxcuVIrVqxQdXW1dVmnpKREaWlpKioquqRGO8rkyZNVVlZmPfd4PEpPT1dubq7sdnuHHcfn88npdGrq7kh527ruNup90/K67Fif1z7zyJEjFR0dHbI+ugrzhjfmDW/Ma772KygXElSAmTBhgh5//HGNHj1akjRgwAB9+OGHqqioUFFRkVJTUyVJbrdbvXr1sl7ndrs1aNAgSVJqaqqampoC1j116pSOHDlivT41NVVutzugpv15e83n2Ww22Wy2M7ZHR0d3ypvqbYvo0t8Dczl8Y3bW1/JyxbzhjXnDG/Oa62LnCOoupBMnTigyMvAlUVFRamv7x+WUjIwMpaamqra21trv8Xi0Y8cOORwOSZLD4VBzc7Pq6+utmo0bN6qtrU3Dhw+3arZs2RJwHczpdOr6668/6+UjAABwZQkqwNx111166qmntGbNGv3lL3/RqlWrNHfuXN17772SpIiICJWUlOjJJ5/Ua6+9pr179+r+++9XWlqa7rnnHklSZmamvvWtb+nhhx/Wzp079cc//lHjx4/X6NGjlZaWJkn60Y9+pJiYGI0dO1b79+/Xyy+/rAULFgRcIgIAAFeuoC4hLVq0SFOnTtVPfvITNTU1KS0tTf/6r/+q8vJyq2bixIk6fvy4HnnkETU3N+uWW27RunXrFBsba9WsWLFC48eP15133qnIyEgVFBRo4cKF1v6EhATV1NSouLhYQ4cO1TXXXKPy8nJuoQYAAJKCDDA9evTQ/PnzNX/+/HPWREREaMaMGZoxY8Y5a5KSklRdXX3eYw0cOFBvvvlmMO0BAIArBH8LCQAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwTlAB5tprr1VERMQZj+LiYknSyZMnVVxcrJ49e+qqq65SQUGB3G53wBqNjY3Kz89XfHy8kpOTNWHCBJ06dSqgZtOmTRoyZIhsNpuuu+46VVVVfbEpAQBAWAkqwOzatUsff/yx9XA6nZKk733ve5Kk0tJSvf7663rllVe0efNmHT58WPfdd5/1+tbWVuXn56ulpUXbtm3TSy+9pKqqKpWXl1s1hw4dUn5+vm6//XY1NDSopKREDz30kNavX98R8wIAgDDQLZjiL33pSwHPn3nmGX3lK1/RN7/5TR09elTLli1TdXW17rjjDknS8uXLlZmZqe3btys7O1s1NTU6cOCANmzYoJSUFA0aNEgzZ87UpEmTNG3aNMXExGjJkiXKyMjQnDlzJEmZmZnaunWr5s2bp7y8vA4aGwAAmCyoAHO6lpYW/frXv1ZZWZkiIiJUX18vn8+nnJwcq6Zfv37q06eP6u
rqlJ2drbq6Og0YMEApKSlWTV5ensaNG6f9+/dr8ODBqqurC1ijvaakpOS8/Xi9Xnm9Xuu5x+ORJPl8Pvl8vksd8wzta9ki/R22ZjDHDYX2Y4eyh67EvOGNecMb85rvYme55ADz6quvqrm5Wf/yL/8iSXK5XIqJiVFiYmJAXUpKilwul1Vzenhp39++73w1Ho9Hn332meLi4s7aT0VFhaZPn37G9pqaGsXHxwc934XMHNbW4Wuez9q1a7v0eGfTfsnwSsG84Y15wxvzmuvEiRMXVXfJAWbZsmUaNWqU0tLSLnWJDjV58mSVlZVZzz0ej9LT05Wbmyu73d5hx/H5fHI6nZq6O1LetogOW/dC9k0L3eWz9plHjhyp6OjokPXRVZg3vDFveGNe87VfQbmQSwowH374oTZs2KDf/e531rbU1FS1tLSoubk54CyM2+1WamqqVbNz586AtdrvUjq95vN3Lrndbtnt9nOefZEkm80mm812xvbo6OhOeVO9bRHytnZdgLkcvjE762t5uWLe8Ma84Y15zXWxc1zS74FZvny5kpOTlZ+fb20bOnSooqOjVVtba207ePCgGhsb5XA4JEkOh0N79+5VU1OTVeN0OmW325WVlWXVnL5Ge037GgAAAEEHmLa2Ni1fvlxFRUXq1u3/TuAkJCRo7NixKisr0xtvvKH6+no98MADcjgcys7OliTl5uYqKytLY8aM0dtvv63169drypQpKi4uts6ePProo/rggw80ceJE/elPf9Lzzz+vlStXqrS0tINGBgAApgv6EtKGDRvU2NioBx988Ix98+bNU2RkpAoKCuT1epWXl6fnn3/e2h8VFaXVq1dr3Lhxcjgc6t69u4qKijRjxgyrJiMjQ2vWrFFpaakWLFig3r1764UXXuAWagAAYAk6wOTm5srvP/stxLGxsaqsrFRlZeU5X9+3b98L3lEzYsQIvfXWW8G2BgAArhD8LSQAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDhBB5i//e1v+vGPf6yePXsqLi5OAwYM0O7du639fr9f5eXl6tWrl+Li4pSTk6P3338/YI0jR46osLBQdrtdiYmJGjt2rI4dOxZQ88477+jWW29VbGys0tPTNXv27EscEQAAhJugAswnn3yim2++WdHR0frDH/6gAwcOaM6cObr66qutmtmzZ2vhwoVasmSJduzYoe7duysvL08nT560agoLC7V//345nU6tXr1aW7Zs0SOPPGLt93g8ys3NVd++fVVfX69nn31W06ZN09KlSztgZAAAYLpuwRTPmjVL6enpWr58ubUtIyPD+m+/36/58+drypQpuvvuuyVJv/rVr5SSkqJXX31Vo0eP1rvvvqt169Zp165dGjZsmCRp0aJF+va3v62f//znSktL04oVK9TS0qIXX3xRMTEx6t+/vxoaGjR37tyAoAMAAK5MQQWY1157TXl5efre976nzZs368tf/rJ+8pOf6OGHH5YkHTp0SC6XSzk5OdZrEhISNHz4cNXV1Wn06NGqq6tTYmKiFV4kKScnR5GRkdqxY4fuvfde1dXV6bbbblNMTIxVk5eXp1mzZumTTz4JOOPTzuv1yuv1Ws89Ho8kyefzyefzBTPmebWvZYv0d9iawRw3FNqPHcoeuhLzhjfmDW/Ma76LnSWoAPPBBx9o8eLFKisr03/8x39o165d+tnPfqaYmBgVFRXJ5XJJklJSUgJel5KSYu1zuVxKTk4ObKJbNyUlJQXUnH5m5/Q1XS7XWQNMRUWFpk+ffsb2mp
oaxcfHBzPmRZk5rK3D1zyftWvXdunxzsbpdIa6hS7FvOGNecMb85rrxIkTF1UXVIBpa2vTsGHD9PTTT0uSBg8erH379mnJkiUqKioKvssONHnyZJWVlVnPPR6P0tPTlZubK7vd3mHH8fl8cjqdmro7Ut62iA5b90L2TcvrsmN9XvvMI0eOVHR0dMj66CrMG96YN7wxr/nar6BcSFABplevXsrKygrYlpmZqf/6r/+SJKWmpkqS3G63evXqZdW43W4NGjTIqmlqagpY49SpUzpy5Ij1+tTUVLnd7oCa9uftNZ9ns9lks9nO2B4dHd0pb6q3LULe1q4LMJfDN2ZnfS0vV8wb3pg3vDGvuS52jqDuQrr55pt18ODBgG3vvfee+vbtK+kfH+hNTU1VbW2ttd/j8WjHjh1yOBySJIfDoebmZtXX11s1GzduVFtbm4YPH27VbNmyJeA6mNPp1PXXX3/Wy0cAAODKElSAKS0t1fbt2/X000/rz3/+s6qrq7V06VIVFxdLkiIiIlRSUqInn3xSr732mvbu3av7779faWlpuueeeyT944zNt771LT388MPauXOn/vjHP2r8+PEaPXq00tLSJEk/+tGPFBMTo7Fjx2r//v16+eWXtWDBgoBLRAAA4MoV1CWkr3/961q1apUmT56sGTNmKCMjQ/Pnz1dhYaFVM3HiRB0/flyPPPKImpubdcstt2jdunWKjY21alasWKHx48frzjvvVGRkpAoKCrRw4UJrf0JCgmpqalRcXKyhQ4fqmmuuUXl5ObdQAwAASUEGGEn6zne+o+985zvn3B8REaEZM2ZoxowZ56xJSkpSdXX1eY8zcOBAvfnmm8G2BwAArgD8LSQAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYJygAsy0adMUERER8OjXr5+1/+TJkyouLlbPnj111VVXqaCgQG63O2CNxsZG5efnKz4+XsnJyZowYYJOnToVULNp0yYNGTJENptN1113naqqqi59QgAAEHaCPgPTv39/ffzxx9Zj69at1r7S0lK9/vrreuWVV7R582YdPnxY9913n7W/tbVV+fn5amlp0bZt2/TSSy+pqqpK5eXlVs2hQ4eUn5+v22+/XQ0NDSopKdFDDz2k9evXf8FRAQBAuOgW9Au6dVNqauoZ248ePaply5apurpad9xxhyRp+fLlyszM1Pbt25Wdna2amhodOHBAGzZsUEpKigYNGqSZM2dq0qRJmjZtmmJiYrRkyRJlZGRozpw5kqTMzExt3bpV8+bNU15e3hccFwAAhIOgA8z777+vtLQ0xcbGyuFwqKKiQn369FF9fb18Pp9ycnKs2n79+qlPnz6qq6tTdna26urqNGDAAKWkpFg1eXl5GjdunPbv36/Bgwerrq4uYI32mpKSkvP25fV65fV6recej0eS5PP55PP5gh3znNrXskX6O2zNYI4bCu3HDmUPXYl5wxvzhjfmNd/FzhJUgBk+fLiqqqp0/fXX6+OPP9b06dN16623at++fXK5XIqJiVFiYmLAa1JSUuRyuSRJLpcrILy072/fd74aj8ejzz77THFxcWftraKiQtOnTz9je01NjeLj44MZ86LMHNbW4Wuez9q1a7v0eGfjdDpD3UKXYt7wxrzhjXnNdeLEiYuqCyrAjBo1yvrvgQMHavjw4erbt69Wrlx5zmDRVSZPnqyysjLrucfjUXp6unJzc2W32zvsOD6fT06nU1N3R8rbFtFh617Ivmmhu3zWPvPIkSMVHR0dsj66CvOGN+
YNb8xrvvYrKBcS9CWk0yUmJuprX/ua/vznP2vkyJFqaWlRc3NzwFkYt9ttfWYmNTVVO3fuDFij/S6l02s+f+eS2+2W3W4/b0iy2Wyy2WxnbI+Oju6UN9XbFiFva9cFmMvhG7OzvpaXK+YNb8wb3pjXXBc7xxf6PTDHjh3Tf//3f6tXr14aOnSooqOjVVtba+0/ePCgGhsb5XA4JEkOh0N79+5VU1OTVeN0OmW325WVlWXVnL5Ge037GgAAAEEFmH//93/X5s2b9Ze//EXbtm3Tvffeq6ioKP3whz9UQkKCxo4dq7KyMr3xxhuqr6/XAw88IIfDoezsbElSbm6usrKyNGbMGL399ttav369pkyZouLiYuvsyaOPPqoPPvhAEydO1J/+9Cc9//zzWrlypUpLSzt+egAAYKSgLiH99a9/1Q9/+EP9/e9/15e+9CXdcsst2r59u770pS9JkubNm6fIyEgVFBTI6/UqLy9Pzz//vPX6qKgorV69WuPGjZPD4VD37t1VVFSkGTNmWDUZGRlas2aNSktLtWDBAvXu3VsvvPACt1ADAABLUAHmt7/97Xn3x8bGqrKyUpWVlees6du37wXvqBkxYoTeeuutYFoDAABXEP4WEgAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgnC8UYJ555hlFRESopKTE2nby5EkVFxerZ8+euuqqq1RQUCC32x3wusbGRuXn5ys+Pl7JycmaMGGCTp06FVCzadMmDRkyRDabTdddd52qqqq+SKsAACCMXHKA2bVrl37xi19o4MCBAdtLS0v1+uuv65VXXtHmzZt1+PBh3Xfffdb+1tZW5efnq6WlRdu2bdNLL72kqqoqlZeXWzWHDh1Sfn6+br/9djU0NKikpEQPPfSQ1q9ff6ntAgCAMHJJAebYsWMqLCzUL3/5S1199dXW9qNHj2rZsmWaO3eu7rjjDg0dOlTLly/Xtm3btH37dklSTU2NDhw4oF//+tcaNGiQRo0apZkzZ6qyslItLS2SpCVLligjI0Nz5sxRZmamxo8fr+9+97uaN29eB4wMAABM1+1SXlRcXKz8/Hzl5OToySeftLbX19fL5/MpJyfH2tavXz/16dNHdXV1ys7OVl1dnQYMGKCUlBSrJi8vT+PGjdP+/fs1ePBg1dXVBazRXnP6parP83q98nq91nOPxyNJ8vl88vl8lzLmWbWvZYv0d9iawRw3FNqPHcoeuhLzhjfmDW/Ma76LnSXoAPPb3/5We/bs0a5du87Y53K5FBMTo8TExIDtKSkpcrlcVs3p4aV9f/u+89V4PB599tlniouLO+PYFRUVmj59+hnba2pqFB8ff/EDXqSZw9o6fM3zWbt2bZce72ycTmeoW+hSzBvemDe8Ma+5Tpw4cVF1QQWYjz76SI899picTqdiY2MvqbHOMnnyZJWVlVnPPR6P0tPTlZubK7vd3mHH8fl8cjqdmro7Ut62iA5b90L2TcvrsmN9XvvMI0eOVHR0dMj66CrMG96YN7wxr/nar6BcSFABpr6+Xk1NTRoyZIi1rbW1VVu2bNFzzz2n9evXq6WlRc3NzQFnYdxut1JTUyVJqamp2rlzZ8C67XcpnV7z+TuX3G637Hb7Wc++SJLNZpPNZjtje3R0dKe8qd62CHlbuy7AXA7fmJ31tbxcMW94Y97wxrzmutg5gvoQ75133qm9e/eqoaHBegwbNkyFhYXWf0dHR6u2ttZ6zcGDB9XY2CiHwyFJcjgc2rt3r5qamqwap9Mpu92urKwsq+b0Ndpr2t
cAAABXtqDOwPTo0UM33HBDwLbu3burZ8+e1vaxY8eqrKxMSUlJstvt+ulPfyqHw6Hs7GxJUm5urrKysjRmzBjNnj1bLpdLU6ZMUXFxsXUG5dFHH9Vzzz2niRMn6sEHH9TGjRu1cuVKrVmzpiNmBgAAhruku5DOZ968eYqMjFRBQYG8Xq/y8vL0/PPPW/ujoqK0evVqjRs3Tg6HQ927d1dRUZFmzJhh1WRkZGjNmjUqLS3VggUL1Lt3b73wwgvKywvd50AAAMDl4wsHmE2bNgU8j42NVWVlpSorK8/5mr59+17wrpoRI0borbfe+qLtAQCAMMTfQgIAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjBNUgFm8eLEGDhwou90uu90uh8OhP/zhD9b+kydPqri4WD179tRVV12lgoICud3ugDUaGxuVn5+v+Ph4JScna8KECTp16lRAzaZNmzRkyBDZbDZdd911qqqquvQJAQBA2AkqwPTu3VvPPPOM6uvrtXv3bt1xxx26++67tX//fklSaWmpXn/9db3yyivavHmzDh8+rPvuu896fWtrq/Lz89XS0qJt27bppZdeUlVVlcrLy62aQ4cOKT8/X7fffrsaGhpUUlKihx56SOvXr++gkQEAgOm6BVN81113BTx/6qmntHjxYm3fvl29e/fWsmXLVF1drTvuuEOStHz5cmVmZmr79u3Kzs5WTU2NDhw4oA0bNiglJUWDBg3SzJkzNWnSJE2bNk0xMTFasmSJMjIyNGfOHElSZmamtm7dqnnz5ikvL6+DxgYAACYLKsCcrrW1Va+88oqOHz8uh8Oh+vp6+Xw+5eTkWDX9+vVTnz59VFdXp+zsbNXV1WnAgAFKSUmxavLy8jRu3Djt379fgwcPVl1dXcAa7TUlJSXn7cfr9crr9VrPPR6PJMnn88nn813qmGdoX8sW6e+wNYM5bii0HzuUPXQl5g1vzBvemNd8FztL0AFm7969cjgcOnnypK666iqtWrVKWVlZamhoUExMjBITEwPqU1JS5HK5JEkulysgvLTvb993vhqPx6PPPvtMcXFxZ+2roqJC06dPP2N7TU2N4uPjgx3zgmYOa+vwNc9n7dq1XXq8s3E6naFuoUsxb3hj3vDGvOY6ceLERdUFHWCuv/56NTQ06OjRo/rP//xPFRUVafPmzUE32NEmT56ssrIy67nH41F6erpyc3Nlt9s77Dg+n09Op1NTd0fK2xbRYeteyL5pobt81j7zyJEjFR0dHbI+ugrzhjfmDW/Ma772KygXEnSAiYmJ0XXXXSdJGjp0qHbt2qUFCxboBz/4gVpaWtTc3BxwFsbtdis1NVWSlJqaqp07dwas136X0uk1n79zye12y263n/PsiyTZbDbZbLYztkdHR3fKm+pti5C3tesCzOXwjdlZX8vLFfOGN+YNb8xrroud4wv/Hpi2tjZ5vV4NHTpU0dHRqq2ttfYdPHhQjY2NcjgckiSHw6G9e/eqqanJqnE6nbLb7crKyrJqTl+jvaZ9DQAAgKDOwEyePFmjRo1Snz599Omnn6q6ulqbNm3S+vXrlZCQoLFjx6qsrExJSUmy2+366U9/KofDoezsbElSbm6usrKyNGbMGM2ePVsul0tTpkxRcXGxdfbk0Ucf1XPPPaeJEyfqwQcf1MaNG7Vy5UqtWbOm46cHAABGCirANDU16f7779fHH3+shIQEDRw4UOvXr9fIkSMlSfPmzVNkZKQKCgrk9XqVl5en559/3np9VFSUVq9erXHjxsnhcKh79+4qKi
rSjBkzrJqMjAytWbNGpaWlWrBggXr37q0XXniBW6gBAIAlqACzbNmy8+6PjY1VZWWlKisrz1nTt2/fC95RM2LECL311lvBtAYAAK4g/C0kAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGCcoAJMRUWFvv71r6tHjx5KTk7WPffco4MHDwbUnDx5UsXFxerZs6euuuoqFRQUyO12B9Q0NjYqPz9f8fHxSk5O1oQJE3Tq1KmAmk2bNmnIkCGy2Wy67rrrVFVVdWkTAgCAsBNUgNm8ebOKi4u1fft2OZ1O+Xw+5ebm6vjx41ZNaWmpXn/9db3yyivavHmzDh8+rPvuu8/a39raqvz8fLW0tGjbtm166aWXVFVVpfLycqvm0KFDys/P1+23366GhgaVlJTooYce0vr16ztgZAAAYLpuwRSvW7cu4HlVVZWSk5NVX1+v2267TUePHtWyZctUXV2tO+64Q5K0fPlyZWZmavv27crOzlZNTY0OHDigDRs2KCUlRYMGDdLMmTM1adIkTZs2TTExMVqyZIkyMjI0Z84cSVJmZqa2bt2qefPmKS8vr4NGBwAApgoqwHze0aNHJUlJSUmSpPr6evl8PuXk5Fg1/fr1U58+fVRXV6fs7GzV1dVpwIABSklJsWry8vI0btw47d+/X4MHD1ZdXV3AGu01JSUl5+zF6/XK6/Vazz0ejyTJ5/PJ5/N9kTEDtK9li/R32JrBHDcU2o8dyh66EvOGN+YNb8xrvoud5ZIDTFtbm0pKSnTzzTfrhhtukCS5XC7FxMQoMTExoDYlJUUul8uqOT28tO9v33e+Go/Ho88++0xxcXFn9FNRUaHp06efsb2mpkbx8fGXNuR5zBzW1uFrns/atWu79Hhn43Q6Q91Cl2Le8Ma84Y15zXXixImLqrvkAFNcXKx9+/Zp69atl7pEh5o8ebLKysqs5x6PR+np6crNzZXdbu+w4/h8PjmdTk3dHSlvW0SHrXsh+6aF7tJZ+8wjR45UdHR0yProKswb3pg3vDGv+dqvoFzIJQWY8ePHa/Xq1dqyZYt69+5tbU9NTVVLS4uam5sDzsK43W6lpqZaNTt37gxYr/0updNrPn/nktvtlt1uP+vZF0my2Wyy2WxnbI+Oju6UN9XbFiFva9cFmMvhG7OzvpaXK+YNb8wb3pjXXBc7R1B3Ifn9fo0fP16rVq3Sxo0blZGREbB/6NChio6OVm1trbXt4MGDamxslMPhkCQ5HA7t3btXTU1NVo3T6ZTdbldWVpZVc/oa7TXtawAAgCtbUGdgiouLVV1drd///vfq0aOH9ZmVhIQExcXFKSEhQWPHjlVZWZmSkpJkt9v105/+VA6HQ9nZ2ZKk3NxcZWVlacyYMZo9e7ZcLpemTJmi4uJi6wzKo48+queee04TJ07Ugw8+qI0bN2rlypVas2ZNB48PAABMFNQZmMWLF+vo0aMaMWKEevXqZT1efvllq2bevHn6zne+o4KCAt12221KTU3V7373O2t/VFSUVq9eraioKDkcDv34xz/W/fffrxkzZlg1GRkZWrNmjZxOp2688UbNmTNHL7zwArdQAwAASUGegfH7L3zrcGxsrCorK1VZWXnOmr59+17wrpoRI0borbfeCqY9AABwheBvIQEAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAI
BxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxgk6wGzZskV33XWX0tLSFBERoVdffTVgv9/vV3l5uXr16qW4uDjl5OTo/fffD6g5cuSICgsLZbfblZiYqLFjx+rYsWMBNe+8845uvfVWxcbGKj09XbNnzw5+OgAAEJaCDjDHjx/XjTfeqMrKyrPunz17thYuXKglS5Zox44d6t69u/Ly8nTy5EmrprCwUPv375fT6dTq1au1ZcsWPfLII9Z+j8ej3Nxc9e3bV/X19Xr22Wc1bdo0LV269BJGBAAA4aZbsC8YNWqURo0addZ9fr9f8+fP15QpU3T33XdLkn71q18pJSVFr776qkaPHq13331X69at065duzRs2DBJ0qJFi/Ttb39bP//5z5WWlqYVK1aopaVFL774omJiYtS/f381NDRo7ty5AUEHAABcmYIOMOdz6NAhuVwu5eTkWNsSEhI0fPhw1dXVafTo0aqrq1NiYqIVXiQpJydHkZGR2rFjh+69917V1dXptttuU0xMjFWTl5enWbNm6ZNPPtHVV199xrG9Xq+8Xq/13OPxSJJ8Pp98Pl+Hzdi+li3S32FrBnPcUGg/dih76ErMG96YN7wxr/kudpYODTAul0uSlJKSErA9JSXF2udyuZScnBzYRLduSkpKCqjJyMg4Y432fWcLMBUVFZo+ffoZ22tqahQfH3+JE53bzGFtHb7m+axdu7ZLj3c2Tqcz1C10KeYNb8wb3pjXXCdOnLioug4NMKE0efJklZWVWc89Ho/S09OVm5sru93eYcfx+XxyOp2aujtS3raIDlv3QvZNy+uyY31e+8wjR45UdHR0yProKswb3pg3vDGv+dqvoFxIhwaY1NRUSZLb7VavXr2s7W63W4MGDbJqmpqaAl536tQpHTlyxHp9amqq3G53QE378/aaz7PZbLLZbGdsj46O7pQ31dsWIW9r1wWYy+Ebs7O+lpcr5g1vzBvemNdcFztHh/4emIyMDKWmpqq2ttba5vF4tGPHDjkcDkmSw+FQc3Oz6uvrrZqNGzeqra1Nw4cPt2q2bNkScB3M6XTq+uuvP+vlIwAAcGUJOsAcO3ZMDQ0NamhokPSPD+42NDSosbFRERERKikp0ZNPPqnXXntNe/fu1f3336+0tDTdc889kqTMzEx961vf0sMPP6ydO3fqj3/8o8aPH6/Ro0crLS1NkvSjH/1IMTExGjt2rPbv36+XX35ZCxYsCLhEBAAArlxBX0LavXu3br/9dut5e6goKipSVVWVJk6cqOPHj+uRRx5Rc3OzbrnlFq1bt06xsbHWa1asWKHx48frzjvvVGRkpAoKCrRw4UJrf0JCgmpqalRcXKyhQ4fqmmuuUXl5ObdQAwAASZcQYEaMGCG//9y3EEdERGjGjBmaMWPGOWuSkpJUXV193uMMHDhQb775ZrDtAQCAKwB/CwkAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDjdQt0AAJjg2sfXdOnx/vJMfpceLxQ642tqi/Jr9k3SDdPWy9saEbAv3L+mV9r3KAHGEF39jXm68/1ACEfhMG8ofrCE8ns0GOHw/nYGU96/L+JK+wc+3F3WAaayslLPPvusXC6XbrzxRi1atEg33XRTqNsCLnvB/KDmH/TLU0f9Y8v7GzpdEZhOf3+lK+v9vWw/A/Pyyy+rrKxMTzzxhPbs2aMbb7xReXl5ampqCnVrAAAgxC7bAD
N37lw9/PDDeuCBB5SVlaUlS5YoPj5eL774YqhbAwAAIXZZXkJqaWlRfX29Jk+ebG2LjIxUTk6O6urqzvoar9crr9drPT969Kgk6ciRI/L5fB3Wm8/n04kTJ9TNF6nWtivjdF23Nr9OnGi7YmZm3vDGvOGNebvO3//+905Z99NPP5Uk+f3+8xf6L0N/+9vf/JL827ZtC9g+YcIE/0033XTW1zzxxBN+STx48ODBgwePMHh89NFH580Kl+UZmEsxefJklZWVWc/b2tp05MgR9ezZUxERHZdKPR6P0tPT9dFHH8lut3fYupezK21m5g1vzBvemNd8fr9fn376qdLS0s5bd1kGmGuuuUZRUVFyu90B291ut1JTU8/6GpvNJpvNFrAtMTGxs1qU3W4Pm2+Wi3Wlzcy84Y15wxvzmi0hIeGCNZflh3hjYmI0dOhQ1dbWWtva2tpUW1srh8MRws4AAMDl4LI8AyNJZWVlKioq0rBhw3TTTTdp/vz5On78uB544IFQtwYAAELssg0wP/jBD/Q///M/Ki8vl8vl0qBBg7Ru3TqlpKSEtC+bzaYnnnjijMtV4exKm5l5wxvzhjfmvXJE+P0Xuk8JAADg8nJZfgYGAADgfAgwAADAOAQYAABgHAIMAAAwDgEmSJWVlbr22msVGxur4cOHa+fOnaFuqVNUVFTo61//unr06KHk5GTdc889OnjwYKjb6jLPPPOMIiIiVFJSEupWOs3f/vY3/fjHP1bPnj0VFxenAQMGaPfu3aFuq1O0trZq6tSpysjIUFxcnL7yla9o5syZF/5bKwbZsmWL7rrrLqWlpSkiIkKvvvpqwH6/36/y8nL16tVLcXFxysnJ0fvvvx+aZjvA+eb1+XyaNGmSBgwYoO7duystLU3333+/Dh8+HLqGv6ALvb+ne/TRRxUREaH58+d3WX+hQIAJwssvv6yysjI98cQT2rNnj2688Ubl5eWpqakp1K11uM2bN6u4uFjbt2+X0+mUz+dTbm6ujh8/HurWOt2uXbv0i1/8QgMHDgx1K53mk08+0c0336zo6Gj94Q9/0IEDBzRnzhxdffXVoW6tU8yaNUuLFy/Wc889p3fffVezZs3S7NmztWjRolC31mGOHz+uG2+8UZWVlWfdP3v2bC1cuFBLlizRjh071L17d+Xl5enkyZNd3GnHON+8J06c0J49ezR16lTt2bNHv/vd73Tw4EH98z//cwg67RgXen/brVq1Stu3b7/gr+EPCx3xxxevFDfddJO/uLjYet7a2upPS0vzV1RUhLCrrtHU1OSX5N+8eXOoW+lUn376qf+rX/2q3+l0+r/5zW/6H3vssVC31CkmTZrkv+WWW0LdRpfJz8/3P/jggwHb7rvvPn9hYWGIOupckvyrVq2ynre1tflTU1P9zz77rLWtubnZb7PZ/L/5zW9C0GHH+vy8Z7Nz506/JP+HH37YNU11onPN+9e//tX/5S9/2b9v3z5/3759/fPmzevy3roSZ2AuUktLi+rr65WTk2Nti4yMVE5Ojurq6kLYWdc4evSoJCkpKSnEnXSu4uJi5efnB7zP4ei1117TsGHD9L3vfU/JyckaPHiwfvnLX4a6rU7zjW98Q7W1tXrvvfckSW+//ba2bt2qUaNGhbizrnHo0CG5XK6A7+uEhAQNHz78ivj5Jf3jZ1hERESn/o28UGpra9OYMWM0YcIE9e/fP9TtdInL9jfxXm7+93//V62trWf8JuCUlBT96U9/ClFXXaOtrU0lJSW6+eabdcMNN4S6nU7z29/+Vnv27NGuXbtC3Uqn++CDD7R48WKVlZXpP/7jP7Rr1y797Gc/U0xMjIqKikLdXod7/PHH5fF41K9fP0VFRam1tVVPPfWUCgsLQ91al3C5XJJ01p9f7fvC2cmTJzVp0iT98Ic/DKs/eHi6WbNmqVu3bvrZz34W6la6DAEGF1RcXKx9+/Zp69atoW6l03z00Ud67LHH5HQ6FRsbG+p2Ol1bW5
uGDRump59+WpI0ePBg7du3T0uWLAnLALNy5UqtWLFC1dXV6t+/vxoaGlRSUqK0tLSwnBf/x+fz6fvf/778fr8WL14c6nY6RX19vRYsWKA9e/YoIiIi1O10GS4hXaRrrrlGUVFRcrvdAdvdbrdSU1ND1FXnGz9+vFavXq033nhDvXv3DnU7naa+vl5NTU0aMmSIunXrpm7dumnz5s1auHChunXrptbW1lC32KF69eqlrKysgG2ZmZlqbGwMUUeda8KECXr88cc1evRoDRgwQGPGjFFpaakqKipC3VqXaP8ZdaX9/GoPLx9++KGcTmfYnn1588031dTUpD59+lg/vz788EP927/9m6699tpQt9dpCDAXKSYmRkOHDlVtba21ra2tTbW1tXI4HCHsrHP4/X6NHz9eq1at0saNG5WRkRHqljrVnXfeqb1796qhocF6DBs2TIWFhWpoaFBUVFSoW+xQN9988xm3xb/33nvq27dviDrqXCdOnFBkZOCPu6ioKLW1tYWoo66VkZGh1NTUgJ9fHo9HO3bsCMufX9L/hZf3339fGzZsUM+ePUPdUqcZM2aM3nnnnYCfX2lpaZowYYLWr18f6vY6DZeQglBWVqaioiINGzZMN910k+bPn6/jx4/rgQceCHVrHa64uFjV1dX6/e9/rx49eljXyRMSEhQXFxfi7jpejx49zvh8T/fu3dWzZ8+w/NxPaWmpvvGNb+jpp5/W97//fe3cuVNLly7V0qVLQ91ap7jrrrv01FNPqU+fPurfv7/eeustzZ07Vw8++GCoW+swx44d05///Gfr+aFDh9TQ0KCkpCT16dNHJSUlevLJJ/XVr35VGRkZmjp1qtLS0nTPPfeErukv4Hzz9urVS9/97ne1Z88erV69Wq2trdbPsKSkJMXExISq7Ut2off38wEtOjpaqampuv7667u61a4T6tugTLNo0SJ/nz59/DExMf6bbrrJv3379lC31CkknfWxfPnyULfWZcL5Nmq/3+9//fXX/TfccIPfZrP5+/Xr51+6dGmoW+o0Ho/H/9hjj/n79Onjj42N9f/TP/2T///9v//n93q9oW6tw7zxxhtn/X+2qKjI7/f/41bqqVOn+lNSUvw2m81/5513+g8ePBjapr+A88176NChc/4Me+ONN0Ld+iW50Pv7eVfCbdQRfn8Y/SpKAABwReAzMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAY5/8Dryo8NpaTnRMAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "# THIS CHECKS IF Q SAMPLE REALLY SAMPLES from Q.\n", "\n", "x = torch.tensor([[0]])\n", "collects = []\n", "for idx in range(10000):\n", " sample = d3pm.q_sample(x, torch.tensor([200], device = 'cuda:0'), torch.rand((*x.shape, d3pm.num_classses), device='cuda:0'))\n", " collects.append(sample.item())\n", "\n", "# plot histogram\n", "import pandas as pd\n", "\n", "df = pd.DataFrame(collects, columns=['sample'])\n", "df['sample'].hist(bins=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "qmats = torch.arange(0, 1000).reshape(10, 10, 10)\n", "image_as_pixels = torch.randint(0, 10, (1, 1, 1))\n", "print(image_as_pixels)\n", "qmats[image_as_pixels].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "qmats[1, 0, 1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "cu122py310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.1" } }, "nbformat": 4, "nbformat_minor": 2 }