Repository: ptrblck/pytorch_misc Branch: master Commit: 31ac50c415f1 Files: 22 Total size: 46.6 KB Directory structure: gitextract_qllenf3o/ ├── LocallyConnected2d.py ├── README.md ├── accumulate_gradients.py ├── adaptive_batchnorm.py ├── adaptive_pooling_torchvision.py ├── batch_norm_manual.py ├── change_crop_in_dataset.py ├── channel_to_patches.py ├── conv_rnn.py ├── csv_chunk_read.py ├── densenet_forwardhook.py ├── edge_weighting_segmentation.py ├── image_rotation_with_matrix.py ├── mnist_autoencoder.py ├── mnist_permuted.py ├── model_sharding_data_parallel.py ├── momentum_update_nograd.py ├── pytorch_redis.py ├── shared_array.py ├── shared_dict.py ├── unet_demo.py └── weighted_sampling.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LocallyConnected2d.py ================================================ """ Test implementation of locally connected 2d layer The first part of the script was used for debugging @author: ptrblck """ import torch import torch.nn as nn from torch.nn.modules.utils import _pair ## DEBUG batch_size = 5 in_channels = 3 h, w = 24, 24 x = torch.ones(batch_size, in_channels, h, w) kh, kw = 3, 3 # kernel_size dh, dw = 1, 1 # stride x_windows = x.unfold(2, kh, dh).unfold(3, kw, dw) x_windows = x_windows.contiguous().view(*x_windows.size()[:-2], -1) out_channels = 2 weights = torch.randn(1, out_channels, in_channels, *x_windows.size()[2:]) output = (x_windows.unsqueeze(1) * weights).sum([2, -1]) ## DEBUG class LocallyConnected2d(nn.Module): def __init__(self, in_channels, out_channels, output_size, kernel_size, stride, bias=False): super(LocallyConnected2d, self).__init__() output_size = _pair(output_size) self.weight = nn.Parameter( torch.randn(1, out_channels, in_channels, output_size[0], output_size[1], kernel_size**2) ) if bias: self.bias = nn.Parameter( torch.randn(1, out_channels, output_size[0], output_size[1]) ) else: self.register_parameter('bias', None) self.kernel_size = _pair(kernel_size) self.stride = _pair(stride) def forward(self, x): _, c, h, w = x.size() kh, kw = self.kernel_size dh, dw = self.stride x = x.unfold(2, kh, dh).unfold(3, kw, dw) x = x.contiguous().view(*x.size()[:-2], -1) # Sum in in_channel and kernel_size dims out = (x.unsqueeze(1) * self.weight).sum([2, -1]) if self.bias is not None: out += self.bias return out # Create input batch_size = 5 in_channels = 3 h, w = 24, 24 x = torch.randn(batch_size, in_channels, h, w) # Create layer and test if backpropagation works out_channels = 2 output_size = 22 kernel_size = 3 stride = 1 conv = LocallyConnected2d( in_channels, out_channels, output_size, kernel_size, stride, bias=True) out = conv(x) out.mean().backward() print(conv.weight.grad) ================================================ FILE: README.md ================================================ # PyTorch misc Collection of code snippets I've written for the [PyTorch discussion board](https://discuss.pytorch.org/). All scripts were testes using the PyTorch 1.0 preview and torchvision `0.2.1`. Additional libraries, e.g. `numpy` or `pandas`, are used in a few scripts. Some scripts might be a good starter to create a tutorial. ## Overview * [accumulate_gradients](https://github.com/ptrblck/pytorch_misc/blob/master/accumulate_gradients.py) - Comparison of accumulated gradients/losses to vanilla batch update. * [adaptive_batchnorm](https://github.com/ptrblck/pytorch_misc/blob/master/adaptive_batchnorm.py)- Adaptive BN implementation using two additional parameters: `out = a * x + b * bn(x)`. * [adaptive_pooling_torchvision](https://github.com/ptrblck/pytorch_misc/blob/master/adaptive_pooling_torchvision.py) - Example of using adaptive pooling layers in pretrained models to use different spatial input shapes. * [batch_norm_manual](https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py) - Comparison of PyTorch BatchNorm layers and a manual calculation. * [change_crop_in_dataset](https://github.com/ptrblck/pytorch_misc/blob/master/change_crop_in_dataset.py) - Change the image crop size on the fly using a Dataset. * [channel_to_patches](https://github.com/ptrblck/pytorch_misc/blob/master/channel_to_patches.py) - Permute image data so that channel values of each pixel are flattened to an image patch around the pixel. * [conv_rnn](https://github.com/ptrblck/pytorch_misc/blob/master/conv_rnn.py) - Combines a 3DCNN with an RNN; uses windowed frames as inputs. * [csv_chunk_read](https://github.com/ptrblck/pytorch_misc/blob/master/csv_chunk_read.py) - Provide data chunks from continuous .csv file. * [densenet_forwardhook](https://github.com/ptrblck/pytorch_misc/blob/master/densenet_forwardhook.py) - Use forward hooks to get intermediate activations from `densenet121`. Uses separate modules to process these activations further. * [edge_weighting_segmentation](https://github.com/ptrblck/pytorch_misc/blob/master/edge_weighting_segmentation.py) - Apply weighting to edges for a segmentation task. * [image_rotation_with_matrix](https://github.com/ptrblck/pytorch_misc/blob/master/image_rotation_with_matrix.py) - Rotate an image given an angle using 1.) a nested loop and 2.) a rotation matrix and mesh grid. * [LocallyConnected2d](https://github.com/ptrblck/pytorch_misc/blob/master/LocallyConnected2d.py) - Implementation of a locally connected 2d layer. * [mnist_autoencoder](https://github.com/ptrblck/pytorch_misc/blob/master/mnist_autoencoder.py) - Simple autoencoder for MNIST data. Includes visualizations of output images, intermediate activations and conv kernels. * [mnist_permuted](https://github.com/ptrblck/pytorch_misc/blob/master/mnist_permuted.py) - MNIST training using permuted pixel locations. * [model_sharding_data_parallel](https://github.com/ptrblck/pytorch_misc/blob/master/model_sharding_data_parallel.py) - Model sharding with `DataParallel` using 2 pairs of 2 GPUs. * [momentum_update_nograd](https://github.com/ptrblck/pytorch_misc/blob/master/momentum_update_nograd.py) - Script to see how parameters are updated when an optimizer is used with momentum/running estimates, even if gradients are zero. * [pytorch_redis](https://github.com/ptrblck/pytorch_misc/blob/master/pytorch_redis.py) - Script to demonstrate the loading data from redis using a PyTorch Dataset and DataLoader. * [shared_array](https://github.com/ptrblck/pytorch_misc/blob/master/shared_array.py) - Script to demonstrate the usage of shared arrays using multiple workers. * [shared_dict](https://github.com/ptrblck/pytorch_misc/blob/master/shared_dict.py) - Script to demonstrate the usage of shared dicts using multiple workers. * [unet_demo](https://github.com/ptrblck/pytorch_misc/blob/master/unet_demo.py) - Simple UNet demo. * [weighted_sampling](https://github.com/ptrblck/pytorch_misc/blob/master/weighted_sampling.py) - Usage of WeightedRandomSampler using an imbalanced dataset with class imbalance 99 to 1. Feedback is very welcome! ================================================ FILE: accumulate_gradients.py ================================================ """ Comparison of accumulated gradients/losses to vanilla batch update. Comments from @albanD: https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20 @author: ptrblck """ import torch import torch.nn as nn # Accumulate loss for each samples # more runtime, more memory x1 = torch.ones(2, 1) w1 = torch.ones(1, 1, requires_grad=True) y1 = torch.ones(2, 1) * 2 criterion = nn.MSELoss() loss1 = 0 for i in range(10): output1 = torch.matmul(x1, w1) loss1 += criterion(output1, y1) loss1 /= 10 # scale loss to match batch gradient loss1.backward() print('Accumulated losses: {}'.format(w1.grad)) # Use whole batch to calculate gradient # least runtime, more memory x2 = torch.ones(20, 1) w2 = torch.ones(1, 1, requires_grad=True) y2 = torch.ones(20, 1) * 2 output2 = torch.matmul(x2, w2) loss2 = criterion(output2, y2) loss2.backward() print('Batch gradient: {}'.format(w2.grad)) # Accumulate scaled gradient # more runtime, least memory x3 = torch.ones(2, 1) w3 = torch.ones(1, 1, requires_grad=True) y3 = torch.ones(2, 1) * 2 for i in range(10): output3 = torch.matmul(x3, w3) loss3 = criterion(output3, y3) loss3 /= 10 loss3.backward() print('Accumulated gradient: {}'.format(w3.grad)) ================================================ FILE: adaptive_batchnorm.py ================================================ """ Implementation of Adaptive BatchNorm @author: ptrblck """ import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torchvision import datasets, transforms # Globals device = 'cuda' if torch.cuda.is_available() else 'cpu' seed = 2809 batch_size = 10 lr = 0.01 log_interval = 10 epochs = 10 torch.manual_seed(seed) class AdaptiveBatchNorm2d(nn.Module): ''' Adaptive BN implementation using two additional parameters: out = a * x + b * bn(x) ''' def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): super(AdaptiveBatchNorm2d, self).__init__() self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine) self.a = nn.Parameter(torch.FloatTensor(1, 1, 1, 1)) self.b = nn.Parameter(torch.FloatTensor(1, 1, 1, 1)) def forward(self, x): return self.a * x + self.b * self.bn(x) class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5) self.conv1_bn = AdaptiveBatchNorm2d(10) self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5) self.conv2_bn = AdaptiveBatchNorm2d(20) self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)), 2)) x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) def train(epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() if batch_idx % log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item())) def test(): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) # sum up batch loss test_loss += F.nll_loss(output, target, size_average=False).item() # get the index of the max log-probability pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).cpu().sum() test_loss /= len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) train_loader = torch.utils.data.DataLoader( datasets.MNIST('./data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST('./data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batch_size, shuffle=True) model = MyNet().to(device) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5) for epoch in range(1, epochs + 1): train(epoch) test() ================================================ FILE: adaptive_pooling_torchvision.py ================================================ """ Adaptive pooling layer examples @author: ptrblck """ import torch import torch.nn as nn import torchvision.models as models # Use standard model with [batch_size, 3, 224, 224] input model = models.vgg16(pretrained=False) batch_size = 1 x = torch.randn(batch_size, 3, 224, 224) output = model(x) # Try bigger input x_big = torch.randn(batch_size, 3, 299, 299) try: output = model(x_big) except RuntimeError as e: print(e) # Try smaller input x_small = torch.randn(batch_size, 3, 128, 128) try: output = model(x_small) except RuntimeError as e: print(e) # Both don't work, since we get a size mismatch for these sizes # Get the size of the last activation map before the classifier def size_hook(module, input, output): print(output.shape) model.features[-1].register_forward_hook(size_hook) output = model(x) # We see that the last pooling layer returns an activation of # [batch_size, 512, 7, 7]. So let's replace it with an adaptive layer with an # output shape of 7x7. model.features[-1] = nn.AdaptiveMaxPool2d(output_size=7) # Now let's try the other shapes again output = model(x_big) output = model(x_small) x_tiny = torch.randn(batch_size, 3, 16, 16) output = model(x_tiny) # Now these inputs are working! # There is however a minimal size as we need a spatial size of at least 1x1 # to pass into the adaptive pooling layer x_too_small = torch.randn(batch_size, 3, 15, 15) try: output = model(x_too_small) except RuntimeError as e: print(e) ================================================ FILE: batch_norm_manual.py ================================================ """ Comparison of manual BatchNorm2d layer implementation in Python and nn.BatchNorm2d @author: ptrblck """ import torch import torch.nn as nn def compare_bn(bn1, bn2): err = False if not torch.allclose(bn1.running_mean, bn2.running_mean): print('Diff in running_mean: {} vs {}'.format( bn1.running_mean, bn2.running_mean)) err = True if not torch.allclose(bn1.running_var, bn2.running_var): print('Diff in running_var: {} vs {}'.format( bn1.running_var, bn2.running_var)) err = True if bn1.affine and bn2.affine: if not torch.allclose(bn1.weight, bn2.weight): print('Diff in weight: {} vs {}'.format( bn1.weight, bn2.weight)) err = True if not torch.allclose(bn1.bias, bn2.bias): print('Diff in bias: {} vs {}'.format( bn1.bias, bn2.bias)) err = True if not err: print('All parameters are equal!') class MyBatchNorm2d(nn.BatchNorm2d): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(MyBatchNorm2d, self).__init__( num_features, eps, momentum, affine, track_running_stats) def forward(self, input): self._check_input_dim(input) exponential_average_factor = 0.0 if self.training and self.track_running_stats: if self.num_batches_tracked is not None: self.num_batches_tracked += 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum # calculate running estimates if self.training: mean = input.mean([0, 2, 3]) # use biased var in train var = input.var([0, 2, 3], unbiased=False) n = input.numel() / input.size(1) with torch.no_grad(): self.running_mean = exponential_average_factor * mean\ + (1 - exponential_average_factor) * self.running_mean # update running_var with unbiased var self.running_var = exponential_average_factor * var * n / (n - 1)\ + (1 - exponential_average_factor) * self.running_var else: mean = self.running_mean var = self.running_var input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps)) if self.affine: input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None] return input # Init BatchNorm layers my_bn = MyBatchNorm2d(3, affine=True) bn = nn.BatchNorm2d(3, affine=True) compare_bn(my_bn, bn) # weight and bias should be different # Load weight and bias my_bn.load_state_dict(bn.state_dict()) compare_bn(my_bn, bn) # Run train for _ in range(10): scale = torch.randint(1, 10, (1,)).float() bias = torch.randint(-10, 10, (1,)).float() x = torch.randn(10, 3, 100, 100) * scale + bias out1 = my_bn(x) out2 = bn(x) compare_bn(my_bn, bn) torch.allclose(out1, out2) print('Max diff: ', (out1 - out2).abs().max()) # Run eval my_bn.eval() bn.eval() for _ in range(10): scale = torch.randint(1, 10, (1,)).float() bias = torch.randint(-10, 10, (1,)).float() x = torch.randn(10, 3, 100, 100) * scale + bias out1 = my_bn(x) out2 = bn(x) compare_bn(my_bn, bn) torch.allclose(out1, out2) print('Max diff: ', (out1 - out2).abs().max()) ================================================ FILE: change_crop_in_dataset.py ================================================ """ Change the crop size on the fly using a Dataset. MyDataset.set_state(stage) switches between crop sizes. Alternatively, the crop size could be specified. @author: ptrblck """ import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms import torchvision.transforms.functional as TF class MyDataset(Dataset): def __init__(self): self.images = [TF.to_pil_image(x) for x in torch.ByteTensor(10, 3, 48, 48)] self.set_stage(0) def __getitem__(self, index): image = self.images[index] # Switch your behavior depending on stage image = self.crop(image) x = TF.to_tensor(image) return x def set_stage(self, stage): if stage == 0: print('Using (32, 32) crops') self.crop = transforms.RandomCrop((32, 32)) elif stage == 1: print('Using (28, 28) crops') self.crop = transforms.RandomCrop((28, 28)) def __len__(self): return len(self.images) dataset = MyDataset() loader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=True) # Use standard crop size for batch_idx, data in enumerate(loader): print('Batch idx {}, data shape {}'.format( batch_idx, data.shape)) # Switch to stage1 crop size loader.dataset.set_stage(1) # Check the shape again for batch_idx, data in enumerate(loader): print('Batch idx {}, data shape {}'.format( batch_idx, data.shape)) ================================================ FILE: channel_to_patches.py ================================================ """ Permute image data so that channel values of each pixel are flattened to an image patch around the pixel. @author: ptrblck """ import torch B, C, H, W = 2, 16, 4, 4 # Create dummy input with same values in each channel x = torch.arange(C)[None, :, None, None].repeat(B, 1, H, W) print(x) # Permute channel dimension to last position and view as 4x4 windows x = x.permute(0, 2, 3, 1).view(B, H, W, 4, 4) print(x) # Permute "window dims" with spatial dims, view as desired output x = x.permute(0, 1, 3, 2, 4).contiguous().view(B, 1, 4*H, 4*W) print(x) ================================================ FILE: conv_rnn.py ================================================ """ Combine Conv3d with an RNN Module. Use windowed frames as inputs. @author: ptrblck """ import torch import torch.nn as nn from torch.utils.data import Dataset class MyModel(nn.Module): def __init__(self, window=16): super(MyModel, self).__init__() self.conv_model = nn.Sequential( nn.Conv3d( in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1), nn.MaxPool3d((1, 2, 2)), nn.ReLU() ) self.rnn = nn.RNN( input_size=6*16*12*12, hidden_size=1, num_layers=1, batch_first=True ) self.hidden = torch.zeros(1, 1, 1) self.window = window def forward(self, x): self.hidden = torch.zeros(1, 1, 1) # reset hidden activations = [] for idx in range(0, x.size(2), self.window): x_ = x[:, :, idx:idx+self.window] x_ = self.conv_model(x_) x_ = x_.view(x_.size(0), 1, -1) activations.append(x_) x = torch.cat(activations, 1) out, hidden = self.rnn(x, self.hidden) return out, hidden class MyDataset(Dataset): ''' Returns windowed frames from sequential data. ''' def __init__(self, frames=512): self.data = torch.randn(3, 2048, 24, 24) self.frames = frames def __getitem__(self, index): index = index * self.frames x = self.data[:, index:index+self.frames] return x def __len__(self): return self.data.size(1) / self.frames model = MyModel() dataset = MyDataset() x = dataset[0] output, hidden = model(x.unsqueeze(0)) ================================================ FILE: csv_chunk_read.py ================================================ """ Provide data chunks from continuous .csv file. @author: ptrblck """ import torch from torch.utils.data import Dataset, DataLoader import numpy as np import pandas as pd # Create dummy csv data nb_samples = 110 a = np.arange(nb_samples) df = pd.DataFrame(a, columns=['data']) df.to_csv('data.csv', index=False) # Create Dataset class CSVDataset(Dataset): def __init__(self, path, chunksize, nb_samples): self.path = path self.chunksize = chunksize self.len = nb_samples // self.chunksize def __getitem__(self, index): ''' Get next chunk of data ''' x = next( pd.read_csv( self.path, skiprows=index * self.chunksize + 1, # +1, since we skip the header chunksize=self.chunksize, names=['data'])) x = torch.from_numpy(x.data.values) return x def __len__(self): return self.len dataset = CSVDataset('data.csv', chunksize=10, nb_samples=nb_samples) loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False) for batch_idx, data in enumerate(loader): print('batch: {}\tdata: {}'.format(batch_idx, data)) ================================================ FILE: densenet_forwardhook.py ================================================ """ Use forward hooks to get intermediate activations from densenet121. Create additional conv layers to process these activations to get a desired number of output channels @author: ptrblck """ import torch import torch.nn as nn from torchvision import models activations = {} def get_activation(name): def hook(model, input, output): activations[name] = output return hook # Create Model model = models.densenet121(pretrained=False) # Register forward hooks with name for name, child in model.features.named_children(): if 'denseblock' in name: print(name) child.register_forward_hook(get_activation(name)) # Forward pass x = torch.randn(1, 3, 224, 224) output = model(x) # Create convs to get desired out_channels out_channels = 1 convs = {'denseblock1': nn.Conv2d(256, out_channels, 1,), 'denseblock2': nn.Conv2d(512, out_channels, 1), 'denseblock3': nn.Conv2d(1024, out_channels, 1), 'denseblock4': nn.Conv2d(1024, out_channels, 1)} # Apply conv on each activation for key in activations: act = activations[key] act = convs[key](act) print(key, act.shape) ================================================ FILE: edge_weighting_segmentation.py ================================================ """ Apply weighting to edges for a segmentation task @author: ptrblck """ import torch import torch.nn as nn import torch.nn.functional as F import matplotlib.pyplot as plt # Create dummy input and target with two squares output = F.log_softmax(torch.randn(1, 3, 24, 24), 1) target = torch.zeros(1, 24, 24, dtype=torch.long) target[0, 4:12, 4:12] = 1 target[0, 14:20, 14:20] = 2 plt.imshow(target[0]) # Edge calculation # Get binary target bin_target = torch.where(target > 0, torch.tensor(1), torch.tensor(0)) plt.imshow(bin_target[0]) # Use average pooling to get edge o = F.avg_pool2d(bin_target.float(), kernel_size=3, padding=1, stride=1) plt.imshow(o[0]) edge_idx = (o.ge(0.01) * o.le(0.99)).float() plt.imshow(edge_idx[0]) # Create weight mask weights = torch.ones_like(edge_idx, dtype=torch.float32) weights_sum0 = weights.sum() # Save initial sum for later rescaling weights = weights + edge_idx * 2. # Weight edged with 2x loss weights_sum1 = weights.sum() weights = weights / weights_sum1 * weights_sum0 # Rescale weigths plt.imshow(weights[0]) # Calculate loss criterion = nn.NLLLoss(reduction='none') loss = criterion(output, target) loss = loss * weights # Apply weighting loss = loss.sum() / weights.sum() # Scale loss ================================================ FILE: image_rotation_with_matrix.py ================================================ """ Rotate image given an angle. 1. Calculate rotated position for each input pixel 2. Use meshgrid and rotation matrix to achieve the same @author: ptrblck """ import torch import numpy as np # Create dummy image batch_size = 1 im = torch.zeros(batch_size, 1, 10, 10) im[:, :, :, 2] = 1. # Set angle angle = torch.tensor([72 * np.pi / 180.]) # Calculate rotation for each target pixel x_mid = (im.size(2) + 1) / 2. y_mid = (im.size(3) + 1) / 2. im_rot = torch.zeros_like(im) for r in range(im.size(2)): for c in range(im.size(3)): x = (r - x_mid) * torch.cos(angle) + (c - y_mid) * torch.sin(angle) y = -1.0 * (r - x_mid) * torch.sin(angle) + (c - y_mid) * torch.cos(angle) x = torch.round(x) + x_mid y = torch.round(y) + y_mid if (x >= 0 and y >= 0 and x < im.size(2) and y < im.size(3)): im_rot[:, :, r, c] = im[:, :, x.long().item(), y.long().item()] # Calculate rotation with inverse rotation matrix rot_matrix = torch.tensor([[torch.cos(angle), torch.sin(angle)], [-1.0*torch.sin(angle), torch.cos(angle)]]) # Use meshgrid for pixel coords xv, yv = torch.meshgrid(torch.arange(im.size(2)), torch.arange(im.size(3))) xv = xv.contiguous() yv = yv.contiguous() src_ind = torch.cat(( (xv.float() - x_mid).view(-1, 1), (yv.float() - y_mid).view(-1, 1)), dim=1 ) # Calculate indices using rotation matrix src_ind = torch.matmul(src_ind, rot_matrix.t()) src_ind = torch.round(src_ind) src_ind += torch.tensor([[x_mid, y_mid]]) # Set out of bounds indices to limits src_ind[src_ind < 0] = 0. src_ind[:, 0][src_ind[:, 0] >= im.size(2)] = float(im.size(2)) - 1 src_ind[:, 1][src_ind[:, 1] >= im.size(3)] = float(im.size(3)) - 1 # Create new rotated image im_rot2 = torch.zeros_like(im) src_ind = src_ind.long() im_rot2[:, :, xv.view(-1), yv.view(-1)] = im[:, :, src_ind[:, 0], src_ind[:, 1]] im_rot2 = im_rot2.view(batch_size, 1, 10, 10) print('Using method 1: {}'.format(im_rot)) print('Using method 2: {}'.format(im_rot2)) ================================================ FILE: mnist_autoencoder.py ================================================ """ Simple autoencoder for MNIST data. Visualizes some output images, intermediate activations as well as some conv kernels. @author: ptrblck """ import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import DataLoader import torchvision.transforms as transforms import torchvision.datasets as datasets import matplotlib.pyplot as plt class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv1 = nn.Conv2d(1, 3, 3, 1, 1) self.pool1 = nn.MaxPool2d(2) self.conv2 = nn.Conv2d(3, 6, 3, 1, 1) self.pool2 = nn.MaxPool2d(2) self.conv_trans1 = nn.ConvTranspose2d(6, 3, 4, 2, 1) self.conv_trans2 = nn.ConvTranspose2d(3, 1, 4, 2, 1) def forward(self, x): x = F.relu(self.pool1(self.conv1(x))) x = F.relu(self.pool2(self.conv2(x))) x = F.relu(self.conv_trans1(x)) x = self.conv_trans2(x) return x dataset = datasets.MNIST( root='./data', transform=transforms.ToTensor() ) loader = DataLoader( dataset, num_workers=2, batch_size=8, shuffle=True ) model = MyModel() criterion = nn.BCEWithLogitsLoss() optimizer = optim.Adam(model.parameters(), lr=1e-3) epochs = 1 for epoch in range(epochs): for batch_idx, (data, target) in enumerate(loader): optimizer.zero_grad() output = model(data) loss = criterion(output, data) loss.backward() optimizer.step() print('Epoch {}, Batch idx {}, loss {}'.format( epoch, batch_idx, loss.item())) def normalize_output(img): img = img - img.min() img = img / img.max() return img # Plot some images idx = torch.randint(0, output.size(0), ()) pred = normalize_output(output[idx, 0]) img = data[idx, 0] fig, axarr = plt.subplots(1, 2) axarr[0].imshow(img.detach().numpy()) axarr[1].imshow(pred.detach().numpy()) # Visualize feature maps activation = {} def get_activation(name): def hook(model, input, output): activation[name] = output.detach() return hook model.conv1.register_forward_hook(get_activation('conv1')) data, _ = dataset[0] data.unsqueeze_(0) output = model(data) act = activation['conv1'].squeeze() fig, axarr = plt.subplots(act.size(0)) for idx in range(act.size(0)): axarr[idx].imshow(act[idx]) # Visualize conv filter kernels = model.conv1.weight.detach() fig, axarr = plt.subplots(kernels.size(0)) for idx in range(kernels.size(0)): axarr[idx].imshow(kernels[idx].squeeze()) ================================================ FILE: mnist_permuted.py ================================================ """ Permute all pixels of MNIST data and try to learn it using simple model. @author: ptrblck """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader import torch.nn.functional as F from torchvision import datasets from torchvision import transforms import numpy as np # Create random indices to permute images indices = np.arange(28*28) np.random.shuffle(indices) def shuffle_image(tensor): tensor = tensor.view(-1)[indices].view(1, 28, 28) return tensor # Apply permuatation using transforms.Lambda train_dataset = datasets.MNIST(root='./data', download=False, train=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Lambda(shuffle_image) ])) test_dataset = datasets.MNIST(root='./data', download=False, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Lambda(shuffle_image) ])) train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True) class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.act = nn.ReLU() self.conv1 = nn.Conv2d(1, 4, 3, 1, 1) self.pool1 = nn.MaxPool2d(2) self.conv2 = nn.Conv2d(4, 8, 3, 1, 1) self.pool2 = nn.MaxPool2d(2) self.fc1 = nn.Linear(7*7*8, 10) def forward(self, x): x = self.act(self.conv1(x)) x = self.pool1(x) x = self.act(self.conv2(x)) x = self.pool2(x) x = x.view(x.size(0), -1) x = F.log_softmax(self.fc1(x), dim=1) return x def train(): acc = 0.0 for batch_idx, (data, target) in enumerate(train_loader): optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() _, pred = torch.max(output, dim=1) accuracy = (pred == target).sum() / float(pred.size(0)) acc += accuracy.data.float() if (batch_idx + 1) % 10 == 0: print('batch idx {}, loss {}'.format( batch_idx, loss.item())) acc /= len(train_loader) print('Train accuracy {}'.format(acc)) def test(): acc = 0.0 losses = 0.0 for batch_idx, (data, target) in enumerate(test_loader): with torch.no_grad(): output = model(data) loss = criterion(output, target) _, pred = torch.max(output, dim=1) accuracy = (pred == target).sum() / float(pred.size(0)) acc += accuracy.data.float() losses += loss.item() acc /= len(test_loader) losses /= len(test_loader) print('Acc {}, loss {}'.format( acc, losses)) model = MyModel() criterion = nn.NLLLoss() optimizer = optim.Adam(model.parameters(), lr=1e-3) train() test() # Visualize filters import matplotlib.pyplot as plt from torchvision.utils import make_grid filts1 = model.conv1.weight.data grid = make_grid(filts1) grid = grid.permute(1, 2, 0) plt.imshow(grid) ================================================ FILE: model_sharding_data_parallel.py ================================================ """ Model sharding with DataParallel using 2 pairs of 2 GPUs. @author: ptrblck """ import torch import torch.nn as nn class SubModule(nn.Module): def __init__(self, in_channels, out_channels): super(SubModule, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) def forward(self, x): print('SubModule, device: {}, shape: {}\n'.format(x.device, x.shape)) x = self.conv1(x) return x class MyModel(nn.Module): def __init__(self, split_gpus, parallel): super(MyModel, self).__init__() self.module1 = SubModule(3, 6) self.module2 = SubModule(6, 1) self.split_gpus = split_gpus self.parallel = parallel if self.split_gpus and self.parallel: self.module1 = nn.DataParallel(self.module1, device_ids=[0, 1]).to('cuda:0') self.module2 = nn.DataParallel(self.module2, device_ids=[2, 3]).to('cuda:2') def forward(self, x): print('Input: device {}, shape {}\n'.format(x.device, x.shape)) x = self.module1(x) print('After module1: device {}, shape {}\n'.format(x.device, x.shape)) x = self.module2(x) print('After module2: device {}, shape {}\n'.format(x.device, x.shape)) return x model = MyModel(split_gpus=True, parallel=True) x = torch.randn(16, 3, 24, 24).to('cuda:0') output = model(x) ================================================ FILE: momentum_update_nograd.py ================================================ """ Script to see how parameters are updated when an optimizer is used with momentum/running estimates, even if gradients are zero. Set use_adam=True to see the effect. Otherwise plain SGD will be used. The model consists of two "decoder" parts, dec1 and dec2. In the first part of the script, you'll see that dec1 will be updated twice, even though this module is not used in the second forward pass. This effect is observed, if one optimizer is used for all parameters. In the second part of the script, two separate optimizers are used and we cannot observe this effect anymore. @author: ptrblck """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim use_adam = True class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.enc = nn.Linear(64, 10) self.dec1 = nn.Linear(10, 64) self.dec2 = nn.Linear(10, 64) def forward(self, x, decoder_idx): x = F.relu(self.enc(x)) if decoder_idx == 1: print('Using dec1') x = self.dec1(x) elif decoder_idx == 2: print('Using dec2') x = self.dec2(x) else: print('Unknown decoder_idx') return x # Create input and model x = torch.randn(1, 64) y = x.clone() model = MyModel() criterion = nn.MSELoss() # Create optimizer using all model parameters if use_adam: optimizer = optim.Adam(model.parameters(), lr=1.) else: optimizer = optim.SGD(model.parameters(), lr=1.) # Save init values old_state_dict = {} for key in model.state_dict(): old_state_dict[key] = model.state_dict()[key].clone() # Training procedure optimizer.zero_grad() output = model(x, 1) loss = criterion(output, y) loss.backward() # Check for gradients in dec1, dec2 print('Dec1 grad: {}\nDec2 grad: {}'.format( model.dec1.weight.grad, model.dec2.weight.grad)) optimizer.step() # Save new params new_state_dict = {} for key in model.state_dict(): new_state_dict[key] = model.state_dict()[key].clone() # Compare params for key in old_state_dict: if not (old_state_dict[key] == new_state_dict[key]).all(): print('Diff in {}'.format(key)) # Update old_state_dict = {} for key in model.state_dict(): old_state_dict[key] = model.state_dict()[key].clone() # Pass through dec2 optimizer.zero_grad() output = model(x, 2) loss = criterion(output, y) loss.backward() print('Dec1 grad: {}\nDec2 grad: {}'.format( model.dec1.weight.grad, model.dec2.weight.grad)) optimizer.step() # Save new params new_state_dict = {} for key in model.state_dict(): new_state_dict[key] = model.state_dict()[key].clone() # Compare params for key in old_state_dict: if not (old_state_dict[key] == new_state_dict[key]).all(): print('Diff in {}'.format(key)) ## Create separate optimizers model = MyModel() dec1_params = list(model.enc.parameters()) + list(model.dec1.parameters()) optimizer1 = optim.Adam(dec1_params, lr=1.) dec2_params = list(model.enc.parameters()) + list(model.dec2.parameters()) optimizer2 = optim.Adam(dec2_params, lr=1.) # Save init values old_state_dict = {} for key in model.state_dict(): old_state_dict[key] = model.state_dict()[key].clone() # Training procedure optimizer1.zero_grad() output = model(x, 1) loss = criterion(output, y) loss.backward() # Check for gradients in dec1, dec2 print('Dec1 grad: {}\nDec2 grad: {}'.format( model.dec1.weight.grad, model.dec2.weight.grad)) optimizer1.step() # Save new params new_state_dict = {} for key in model.state_dict(): new_state_dict[key] = model.state_dict()[key].clone() # Compare params for key in old_state_dict: if not (old_state_dict[key] == new_state_dict[key]).all(): print('Diff in {}'.format(key)) # Update old_state_dict = {} for key in model.state_dict(): old_state_dict[key] = model.state_dict()[key].clone() # Pass through dec2 optimizer1.zero_grad() output = model(x, 2) loss = criterion(output, y) loss.backward() print('Dec1 grad: {}\nDec2 grad: {}'.format( model.dec1.weight.grad, model.dec2.weight.grad)) optimizer2.step() # Save new params new_state_dict = {} for key in model.state_dict(): new_state_dict[key] = model.state_dict()[key].clone() # Compare params for key in old_state_dict: if not (old_state_dict[key] == new_state_dict[key]).all(): print('Diff in {}'.format(key)) ================================================ FILE: pytorch_redis.py ================================================ """ Shows how to store and load data from redis using a PyTorch Dataset and DataLoader (with multiple workers). @author: ptrblck """ import redis import torch from torch.utils.data import Dataset, DataLoader import torchvision.transforms as transforms import numpy as np # Create random data and push to redis r = redis.Redis(host='localhost', port=6379, db=0) nb_images = 100 for idx in range(nb_images): # Use long for the fake images, as it's easier to store the target with it data = np.random.randint(0, 256, (3, 24, 24), dtype=np.long).tobytes() target = bytes(np.random.randint(0, 10, (1,)).astype(np.long)) r.set(idx, data + target) # Create RedisDataset class RedisDataset(Dataset): def __init__(self, redis_host='localhost', redis_port=6379, redis_db=0, length=0, transform=None): self.db = redis.Redis(host=redis_host, port=redis_port, db=redis_db) self.length = length self.transform = transform def __getitem__(self, index): data = self.db.get(index) data = np.frombuffer(data, dtype=np.long) x = data[:-1].reshape(3, 24, 24).astype(np.uint8) y = torch.tensor(data[-1]).long() if self.transform: x = self.transform(x) return x, y def __len__(self): return self.length # Load samples from redis using multiprocessing dataset = RedisDataset(length=100, transform=transforms.ToTensor()) loader = DataLoader( dataset, batch_size=10, num_workers=2, shuffle=True ) for data, target in loader: print(data.shape) print(target.shape) ================================================ FILE: shared_array.py ================================================ """ Script to demonstrate the usage of shared arrays using multiple workers. In the first epoch the shared arrays in the dataset will be filled with random values. After setting set_use_cache(True), the shared values will be loaded from multiple processes. @author: ptrblck """ import torch from torch.utils.data import Dataset, DataLoader import ctypes import multiprocessing as mp import numpy as np class MyDataset(Dataset): def __init__(self): shared_array_base = mp.Array(ctypes.c_float, nb_samples*c*h*w) shared_array = np.ctypeslib.as_array(shared_array_base.get_obj()) shared_array = shared_array.reshape(nb_samples, c, h, w) self.shared_array = torch.from_numpy(shared_array) self.use_cache = False def set_use_cache(self, use_cache): self.use_cache = use_cache def __getitem__(self, index): if not self.use_cache: print('Filling cache for index {}'.format(index)) # Add your loading logic here self.shared_array[index] = torch.randn(c, h, w) x = self.shared_array[index] return x def __len__(self): return nb_samples nb_samples, c, h, w = 10, 3, 24, 24 dataset = MyDataset() loader = DataLoader( dataset, num_workers=2, shuffle=False ) for epoch in range(2): for idx, data in enumerate(loader): print('Epoch {}, idx {}, data.shape {}'.format(epoch, idx, data.shape)) if epoch == 0: loader.dataset.set_use_cache(True) ================================================ FILE: shared_dict.py ================================================ """ Script to demonstrate the usage of shared dicts using multiple workers. In the first epoch the shared dict in the dataset will be filled with random values. The next epochs will just use the dict without "loading" the data again. @author: ptrblck """ from multiprocessing import Manager import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self, shared_dict, length): self.shared_dict = shared_dict self.length = length def __getitem__(self, index): if index not in self.shared_dict: print('Adding {} to shared_dict'.format(index)) self.shared_dict[index] = torch.tensor(index) return self.shared_dict[index] def __len__(self): return self.length # Init manager = Manager() shared_dict = manager.dict() dataset = MyDataset(shared_dict, length=100) loader = DataLoader( dataset, batch_size=10, num_workers=6, shuffle=True, pin_memory=True ) # First loop will add data to the shared_dict for x in loader: print(x) # The second loop will just get the data for x in loader: print(x) ================================================ FILE: unet_demo.py ================================================ """ Simple UNet demo @author: ptrblck """ import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F class BaseConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, padding, stride): super(BaseConv, self).__init__() self.act = nn.ReLU() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding, stride) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding, stride) def forward(self, x): x = self.act(self.conv1(x)) x = self.act(self.conv2(x)) return x class DownConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, padding, stride): super(DownConv, self).__init__() self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv_block = BaseConv(in_channels, out_channels, kernel_size, padding, stride) def forward(self, x): x = self.pool1(x) x = self.conv_block(x) return x class UpConv(nn.Module): def __init__(self, in_channels, in_channels_skip, out_channels, kernel_size, padding, stride): super(UpConv, self).__init__() self.conv_trans1 = nn.ConvTranspose2d( in_channels, in_channels, kernel_size=2, padding=0, stride=2) self.conv_block = BaseConv( in_channels=in_channels + in_channels_skip, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride) def forward(self, x, x_skip): x = self.conv_trans1(x) x = torch.cat((x, x_skip), dim=1) x = self.conv_block(x) return x class UNet(nn.Module): def __init__(self, in_channels, out_channels, n_class, kernel_size, padding, stride): super(UNet, self).__init__() self.init_conv = BaseConv(in_channels, out_channels, kernel_size, padding, stride) self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size, padding, stride) self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size, padding, stride) self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size, padding, stride) self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels, kernel_size, padding, stride) self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels, kernel_size, padding, stride) self.up1 = UpConv(2 * out_channels, out_channels, out_channels, kernel_size, padding, stride) self.out = nn.Conv2d(out_channels, n_class, kernel_size, padding, stride) def forward(self, x): # Encoder x = self.init_conv(x) x1 = self.down1(x) x2 = self.down2(x1) x3 = self.down3(x2) # Decoder x_up = self.up3(x3, x2) x_up = self.up2(x_up, x1) x_up = self.up1(x_up, x) x_out = F.log_softmax(self.out(x_up), 1) return x_out # Create 10-class segmentation dummy image and target nb_classes = 10 x = torch.randn(1, 3, 96, 96) y = torch.randint(0, nb_classes, (1, 96, 96)) model = UNet(in_channels=3, out_channels=64, n_class=10, kernel_size=3, padding=1, stride=1) if torch.cuda.is_available(): model = model.to('cuda') x = x.to('cuda') y = y.to('cuda') criterion = nn.NLLLoss() optimizer = optim.Adam(model.parameters(), lr=1e-3) # Training loop for epoch in range(1): optimizer.zero_grad() output = model(x) loss = criterion(output, y) loss.backward() optimizer.step() print('Epoch {}, Loss {}'.format(epoch, loss.item())) ================================================ FILE: weighted_sampling.py ================================================ """ Usage of WeightedRandomSampler using an imbalanced dataset with class imbalance 99 to 1. @author: ptrblck """ import torch from torch.utils.data.sampler import WeightedRandomSampler from torch.utils.data.dataloader import DataLoader # Create dummy data with class imbalance 99 to 1 numDataPoints = 1000 data_dim = 5 bs = 100 data = torch.randn(numDataPoints, data_dim) target = torch.cat((torch.zeros(int(numDataPoints * 0.99), dtype=torch.long), torch.ones(int(numDataPoints * 0.01), dtype=torch.long))) print('target train 0/1: {}/{}'.format( (target == 0).sum(), (target == 1).sum())) # Compute samples weight (each sample should get its own weight) class_sample_count = torch.tensor( [(target == t).sum() for t in torch.unique(target, sorted=True)]) weight = 1. / class_sample_count.float() samples_weight = torch.tensor([weight[t] for t in target]) # Create sampler, dataset, loader sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) train_dataset = torch.utils.data.TensorDataset(data, target) train_loader = DataLoader( train_dataset, batch_size=bs, num_workers=1, sampler=sampler) # Iterate DataLoader and check class balance for each batch for i, (x, y) in enumerate(train_loader): print("batch index {}, 0/1: {}/{}".format( i, (y == 0).sum(), (y == 1).sum()))