[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Nikhil Vyas\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# SOAP\n\nThis is the official (preliminary) implementation of the SOAP optimizer from [SOAP: Improving and Stabilizing Shampoo using Adam](https://arxiv.org/abs/2409.11321). To use, copy the soap.py file to your codebase and use SOAP optimizer in the following fashion:\n\n```\nfrom soap import SOAP\n\noptim = SOAP(lr = 3e-3, betas=(.95, .95), weight_decay=.01, precondition_frequency=10)\n```\n\nWe recommend trying it with as large batch size as possible, as expected from second order optimizers, the benefits are larger at larger batch sizes.\n\nWhile in the paper our experiments are restricted to Transformers which only have 2D layers, the code supports nD layers. If you are using the optimizer for (n > 2) nD layers please see additional hyperparameters in soap.py.\n\n\nWe will release an improved version of the optimizer with support for lower precision and distributed training. \n\n\nHaydn Jones has implemented a JAX version at https://github.com/haydn-jones/SOAP_JAX, though we have not yet verified the implementation.\n"
  },
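  {
    "path": "example_usage.py",
    "content": "\"\"\"Minimal usage sketch for the SOAP optimizer on a toy regression task.\n\nThe model, data, and training loop below are illustrative assumptions; only the\nSOAP(...) hyperparameters follow the suggestion in README.md.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom soap import SOAP\n\n# Toy two-layer MLP; SOAP builds one preconditioner pair per 2D weight matrix.\nmodel = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))\noptim = SOAP(model.parameters(), lr=3e-3, betas=(0.95, 0.95), weight_decay=0.01, precondition_frequency=10)\n\n# Random regression data purely for illustration.\nx = torch.randn(256, 32)\ny = torch.randn(256, 1)\n\nfor step in range(100):\n    optim.zero_grad()\n    loss = nn.functional.mse_loss(model(x), y)\n    loss.backward()\n    optim.step()\n"
  },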
  {
    "path": "soap.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nfrom itertools import chain\n\n# Parts of the code are modifications of Pytorch's AdamW optimizer\n# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py\n\n\nclass SOAP(optim.Optimizer):\n    \"\"\"\n    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).\n\n    Parameters:\n        params (`Iterable[nn.parameter.Parameter]`):\n            Iterable of parameters to optimize or dictionaries defining parameter groups.\n        lr (`float`, *optional*, defaults to 0.003):\n            The learning rate to use.\n        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):\n            Adam's betas parameters (b1, b2).\n        shampoo_beta (`float`, *optional*, defaults to -1):\n            If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].\n        eps (`float`, *optional*, defaults to 1e-08):\n            Adam's epsilon for numerical stability.\n        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.\n        precondition_frequency (`int`, *optional*, defaults to 10):\n            How often to update the preconditioner.\n        max_precond_dim (`int`, *optional*, defaults to 10000):\n            Maximum dimension of the preconditioner.\n            Set to 10000, so that we exclude most common vocab sizes while including layers.\n        merge_dims (`bool`, *optional*, defaults to `False`):\n            Whether or not to merge dimensions of the preconditioner.\n        precondition_1d (`bool`, *optional*, defaults to `False`):\n            Whether or not to precondition 1D gradients.\n        normalize_grads (`bool`, *optional*, defaults to `False`):\n            Whether or not to normalize gradients per layer. 
\n            Helps at large precondition_frequency (~100 in our experiments), \n            but hurts performance at small precondition_frequency (~10 in our experiments).\n        data_format (`str`, *optional*, defaults to `channels_first`):\n            Data format of the input for convolutional layers.\n            Should be \"channels_last\" for data_format of NHWC and \"channels_first\" for NCHW.\n        correct_bias (`bool`, *optional*, defaults to `True`):\n            Whether or not to use bias correction in Adam.\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr: float = 3e-3,\n        betas=(0.95, 0.95),\n        shampoo_beta: float= -1,\n        eps: float = 1e-8,\n        weight_decay: float = 0.01,\n        precondition_frequency: int=10,\n        max_precond_dim: int=10000, # \n        merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.\n        precondition_1d: bool = False,\n        normalize_grads: bool = False,\n        data_format: str = \"channels_first\",\n        correct_bias: bool = True,\n    ):\n        defaults = {\n            \"lr\": lr,\n            \"betas\": betas,\n            \"shampoo_beta\": shampoo_beta,\n            \"eps\": eps,\n            \"weight_decay\": weight_decay,\n            \"precondition_frequency\": precondition_frequency,\n            \"max_precond_dim\": max_precond_dim,\n            \"merge_dims\": merge_dims,\n            \"precondition_1d\": precondition_1d,\n            \"normalize_grads\": normalize_grads,\n            \"correct_bias\": correct_bias,\n        }\n        super().__init__(params, defaults)\n        self._data_format = data_format\n        \n    def merge_dims(self, grad, max_precond_dim):\n        \"\"\"\n        Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.\n        \"\"\"\n        assert self._data_format in [\"channels_first\", \"channels_last\"]\n        if self._data_format == \"channels_last\" and grad.dim() == 4:\n            grad = grad.permute(0, 3, 1, 2)\n        shape = grad.shape\n        new_shape = []\n        \n        curr_shape = 1\n        for sh in shape:\n            temp_shape = curr_shape * sh\n            if temp_shape > max_precond_dim:\n                if curr_shape > 1:\n                    new_shape.append(curr_shape)\n                    curr_shape = sh\n                else:\n                    new_shape.append(sh)\n                    curr_shape = 1\n            else:\n                curr_shape = temp_shape\n        \n        if curr_shape > 1 or len(new_shape)==0:\n            new_shape.append(curr_shape)\n        \n        new_grad = grad.reshape(new_shape)\n        return new_grad               \n\n    @torch.no_grad()\n    def step(self, closure = None):\n        \"\"\"\n        Performs a single optimization step.\n\n        Arguments:\n            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.\n        \"\"\"\n        if closure is None:\n            loss = None\n        else:\n            loss = closure()\n        \n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad\n\n                state = self.state[p]\n                \n                if \"step\" not in state:\n                    state[\"step\"] = 0 \n                    \n           
     # State initialization\n                if \"exp_avg\" not in state:\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(grad)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros_like(grad)\n                \n                if 'Q' not in state:\n                    self.init_preconditioner(\n                        grad,\n                        state,\n                        precondition_frequency=group['precondition_frequency'],\n                        precondition_1d=group['precondition_1d'],\n                        shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group[\"betas\"][1]),\n                        max_precond_dim=group['max_precond_dim'],\n                        merge_dims=group[\"merge_dims\"],\n                    )\n                    self.update_preconditioner(grad, state,\n                                               max_precond_dim=group['max_precond_dim'],\n                                               merge_dims=group[\"merge_dims\"],\n                                               precondition_1d=group[\"precondition_1d\"])\n                    continue # first step is skipped so that we never use the current gradients in the projection.\n                \n                # Projecting gradients to the eigenbases of Shampoo's preconditioner \n                # i.e. projecting to the eigenbases of matrices in state['GG']\n                grad_projected = self.project(grad, state, merge_dims=group[\"merge_dims\"], \n                                              max_precond_dim=group['max_precond_dim'])\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                state[\"step\"] += 1\n\n                # Decay the first and second moment running average coefficient\n                # In-place operations to update the averages at the same time\n                exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))\n                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))\n\n                denom = exp_avg_sq.sqrt().add_(group[\"eps\"])\n                \n                # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner \n                # i.e. 
projecting to the eigenbases of matrices in state['GG']\n                # exp_avg_projected = self.project(exp_avg, state, merge_dims=group[\"merge_dims\"],\n                #                                  max_precond_dim=group['max_precond_dim'])\n                exp_avg_projected = exp_avg\n                \n                step_size = group[\"lr\"]\n                if group[\"correct_bias\"]:\n                    bias_correction1 = 1.0 - beta1 ** (state[\"step\"])\n                    bias_correction2 = 1.0 - beta2 ** (state[\"step\"])\n                    step_size = step_size * (bias_correction2 ** .5) / bias_correction1\n\n                # Projecting back the preconditioned (by Adam) exponential moving average of gradients\n                # to the original space\n                norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=group[\"merge_dims\"],\n                                                 max_precond_dim=group['max_precond_dim'])\n\n                if group[\"normalize_grads\"]:\n                    norm_grad = norm_grad / (1e-30+torch.mean(norm_grad**2)**0.5)\n                \n                p.add_(norm_grad, alpha=-step_size)\n                \n\n                # From AdamW code: Just adding the square of the weights to the loss function is *not*\n                # the correct way of using L2 regularization/weight decay with Adam,\n                # since that will interact with the m and v parameters in strange ways.\n                #\n                # Instead we want to decay the weights in a manner that doesn't interact\n                # with the m/v parameters. This is equivalent to adding the square\n                # of the weights to the loss with plain (non-momentum) SGD.\n                # Add weight decay at the end (fixed version)\n                if group[\"weight_decay\"] > 0.0:\n                    p.add_(p, alpha=(-group[\"lr\"] * group[\"weight_decay\"]))\n                    \n                # Update is done after the gradient step to avoid using current gradients in the projection.\n                self.update_preconditioner(grad, state, \n                                               max_precond_dim=group['max_precond_dim'],\n                                               merge_dims=group[\"merge_dims\"],\n                                               precondition_1d=group[\"precondition_1d\"])\n        \n        return loss\n    \n    def init_preconditioner(self, grad, state, precondition_frequency=10, \n                            shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,\n                            merge_dims=False):\n        \"\"\"\n        Initializes the preconditioner matrices (L and R in the paper).\n        \"\"\"\n        state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).\n        if grad.dim() == 1:\n            if not precondition_1d or grad.shape[0] > max_precond_dim:\n                state['GG'].append([])\n            else:\n                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))\n        else:\n            if merge_dims:\n                grad = self.merge_dims(grad, max_precond_dim)\n\n            for sh in grad.shape:\n                if sh > max_precond_dim:\n                    state['GG'].append([])\n                else:\n                    state['GG'].append(torch.zeros(sh, sh, device=grad.device))\n                    \n        state['Q'] = None # Will hold all the eigenbases of the 
preconditioner.\n        state['precondition_frequency'] = precondition_frequency\n        state['shampoo_beta'] = shampoo_beta          \n        \n    def project(self, grad, state, merge_dims=False, max_precond_dim=10000):\n        \"\"\"\n        Projects the gradient to the eigenbases of the preconditioner.\n        \"\"\"\n        original_shape = grad.shape\n        if merge_dims:\n            if grad.dim() == 4 and self._data_format == 'channels_last':\n                permuted_shape = grad.permute(0, 3, 1, 2).shape\n            grad = self.merge_dims(grad, max_precond_dim)\n\n        for mat in state['Q']:\n            if len(mat) > 0:\n                grad = torch.tensordot(\n                        grad,\n                        mat,\n                        dims=[[0], [0]],\n                    )\n            else:\n                permute_order = list(range(1, len(grad.shape))) + [0]\n                grad = grad.permute(permute_order)\n        \n        if merge_dims:\n            if self._data_format == 'channels_last' and len(original_shape) == 4:\n                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)\n            else:\n                grad = grad.reshape(original_shape)\n        return grad\n        \n    def update_preconditioner(self, grad, state, \n                              max_precond_dim=10000, merge_dims=False, precondition_1d=False):\n        \"\"\"\n        Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).\n        \"\"\"\n        if state[\"Q\"] is not None:\n            state[\"exp_avg\"] = self.project_back(state[\"exp_avg\"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)\n        if grad.dim() == 1:\n            if precondition_1d and grad.shape[0] <= max_precond_dim:\n                state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])\n        else:\n            if merge_dims:\n                new_grad = self.merge_dims(grad, max_precond_dim)\n                for idx, sh in enumerate(new_grad.shape):\n                    if sh <= max_precond_dim:\n                        outer_product = torch.tensordot(\n                                new_grad,\n                                new_grad,\n                                dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,\n                            )\n                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])\n            else:\n                for idx, sh in enumerate(grad.shape):\n                    if sh <= max_precond_dim:\n                        outer_product = torch.tensordot(\n                                grad,\n                                grad,\n                                # Contracts across all dimensions except for k.\n                                dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,\n                            )\n                        state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])\n                     \n        if state['Q'] is None:\n            state['Q'] = self.get_orthogonal_matrix(state['GG'])\n        if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:\n            state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)\n            # state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)             \n\n        if state[\"step\"] > 0:\n            state[\"exp_avg\"] = self.project(state[\"exp_avg\"], state, 
merge_dims=merge_dims, max_precond_dim=max_precond_dim) \n\n    def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):\n        \"\"\"\n        Projects the gradient back to the original space.\n        \"\"\"\n        original_shape = grad.shape\n        if merge_dims:\n            if self._data_format == 'channels_last' and grad.dim() == 4:\n                permuted_shape = grad.permute(0, 3, 1, 2).shape\n            grad = self.merge_dims(grad, max_precond_dim)\n        for mat in state['Q']:\n            if len(mat) > 0:\n                grad = torch.tensordot(\n                        grad,\n                        mat,\n                        dims=[[0], [1]],\n                    )\n            else:\n                permute_order = list(range(1, len(grad.shape))) + [0]\n                grad = grad.permute(permute_order)\n                \n        if merge_dims:\n            if self._data_format == 'channels_last' and len(original_shape) == 4:\n                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)\n            else:\n                grad = grad.reshape(original_shape)\n        return grad\n        \n\n    def get_orthogonal_matrix(self, mat):\n        \"\"\"\n        Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.\n        \"\"\"\n        matrix = []\n        for m in mat:\n            if len(m) == 0:\n                matrix.append([])\n                continue\n            if m.data.dtype != torch.float:\n                float_data = False\n                original_type = m.data.dtype\n                original_device = m.data.device\n                matrix.append(m.data.float())\n            else:\n                float_data = True\n                matrix.append(m.data)\n        \n        final = []\n        for m in matrix:\n            if len(m) == 0:\n                final.append([])\n                continue\n            try:\n                _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device))\n            except:\n                _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device))\n                Q = Q.to(m.dtype)\n            Q = torch.flip(Q, [1])\n\n            if not float_data:\n                Q = Q.to(original_device).type(original_type)\n            final.append(Q)\n        return final\n        \n\n    def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):\n        \"\"\"\n        Computes the eigenbases of the preconditioner using one round of power iteration \n        followed by torch.linalg.qr decomposition.\n        \"\"\"\n        precond_list = state['GG']\n        orth_list = state['Q']\n\n        matrix = []\n        orth_matrix = []\n        for m,o in zip(precond_list, orth_list):\n            if len(m) == 0:\n                matrix.append([])\n                orth_matrix.append([])\n                continue\n            if m.data.dtype != torch.float:\n                float_data = False\n                original_type = m.data.dtype\n                original_device = m.data.device\n                matrix.append(m.data.float())\n                orth_matrix.append(o.data.float())\n            else:\n                float_data = True\n                matrix.append(m.data.float())\n                orth_matrix.append(o.data.float())\n        \n        orig_shape = state['exp_avg_sq'].shape\n        if self._data_format == 'channels_last' and len(orig_shape) == 4:\n            permuted_shape = 
state['exp_avg_sq'].permute(0, 3, 1, 2).shape\n        if merge_dims:\n            exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)\n        else:\n            exp_avg_sq = state['exp_avg_sq']\n            \n        final = []\n        for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):\n            if len(m)==0:\n                final.append([])\n                continue\n            est_eig = torch.diag(o.T @ m @ o)\n            sort_idx = torch.argsort(est_eig, descending=True)\n            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)\n            o = o[:,sort_idx]\n            power_iter = m @ o\n            Q, _ = torch.linalg.qr(power_iter)\n\n            if not float_data:\n                Q = Q.to(original_device).type(original_type)\n            final.append(Q)\n        \n        if merge_dims:\n            if self._data_format == 'channels_last' and len(orig_shape) == 4:\n                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)\n            else:\n                exp_avg_sq = exp_avg_sq.reshape(orig_shape)\n                \n        state['exp_avg_sq'] = exp_avg_sq\n        return final\n    \n    "
  }
]