[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Supertone Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Super-Monotonic-Alignment-Search\n\n[![TechnicalReport](https://img.shields.io/badge/TechnicalReport-2409.07704-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2409.07704) \n\nThis repo contains [Triton-Lang](https://github.com/triton-lang/triton) and PyTorch implementation of the monotonic alignment search (MAS), originally from [Glow-TTS](https://arxiv.org/abs/2005.11129).\nMAS is an effective algorithm for estimating the alignment between paired speech and text in a self-supervised manner.\n\n![Image0](./assets/memory_read_write.png)\n\n\nThe authors of Glow-TTS noted:\n> \"The time complexity of the algorithm is O(T_{text} × T_{mel}). Even though the algorithm is difficult to parallelize, it runs efficiently on CPU without the need for GPU executions. In our experiments, it spends less than 20 ms on each iteration, which amounts to less than 2% of the total training time. Furthermore, we do not need MAS during inference, as the duration predictor is used to estimate the alignment.\"\n\nHowever, we found three issues while using MAS.\n1. MAS can be parallelized in the text-length dimension, while the original implementation uses nested loops.\n2. CPU execution consumes an inordinate amount of time for large inputs due to the need to copy large tensors between CPU and GPU.\n3. The hard-coded value of max_neg_val at -1e9 is insufficient to prevent alignment mismatches in the upper diagonal parts.\n\nTherefore, we implemented a Triton kernel `super_monotonic_align` and PyTorch code `jit_monotonic_align` to accelerate MAS on GPU without inter-device copy.\n\n# Requirments\n1. PyTorch (tested with version `torch==2.3.0+cu121`)\n2. Triton-Lang (tested with version `triton==2.3.0`)\n3. Cython (optional for bench, tested with version `Cython== 0.29.36`)\n\nPlease ensure you have these packages installed to run the code in this repository, as version checks are not enforced.\n\n# How to use\n1. 
Install super-monotonic-align\n```\ngit clone git@github.com:supertone-inc/super-monotonic-align.git\ncd super-monotonic-align; pip install -e ./\n```\nor\n```\npip install git+https://github.com/supertone-inc/super-monotonic-align.git\n```\n2. Import `super_monotonic_align` and use it!\n```python\nfrom super_monotonic_align import maximum_path\n...\n# Note: `value` is modified in place by the Triton kernel.\n# To keep the original values, clone `value` before calling maximum_path.\n# B: batch_size, T: text_length, S: audio_length\nvalue = torch.randn((B, T, S), dtype=torch.float32, device='cuda')\nattn_mask = torch.ones((B, T, S), dtype=torch.int32, device='cuda')\n# path: [B,T,S] tensor, you can specify path's dtype, default=torch.float32\npath = maximum_path(value, attn_mask, dtype=torch.bool)\n```\n\n## Warning\n\nPlease **check your input shape** before use.\n\nOur implementation uses the shape \\[B, T, S\\], identical to the Glow-TTS version, while the VITS implementation uses the shape \\[B, S, T\\] (thanks to [codeghees](https://github.com/codeghees) for raising this issue). 
\n\nFor now, we recommend transposing the input if you are using \\[B, S, T\\]-shaped input, but we will soon release an option that supports \\[B, S, T\\] as well.\n\n# Benchmark\n```\nMAS in ms:\n         T      Triton       JIT_v1       JIT_v2       Cython\n0    128.0    0.447488    83.742203    53.222176     8.819136\n1    256.0    1.616896   155.424774   104.632477    43.533665\n2    384.0    3.430400   325.307404   237.820435   136.257538\n3    512.0    5.838848   439.984131   344.654236   304.981201\n4    640.0    9.070592   532.910095   452.141907   462.405304\n5    768.0   12.249088   655.960083   587.169739   488.272858\n6    896.0   15.203328   557.997070   620.148315   863.919067\n7   1024.0   19.778561   627.986450   815.933167  1299.567871\n8   1152.0   33.276928   706.022400   968.533813  1467.056885\n9   1280.0   39.800835   792.861694  1215.021240  1930.171509\n10  1408.0   47.456257   903.750671  1289.656250  2231.598145\n11  1536.0   59.238914   953.907227  1523.870972  2959.377930\n12  1664.0   70.068741  1031.818237  2004.299438  3073.532471\n13  1792.0   82.205696  1558.200317  2359.347900  3930.776367\n14  1920.0   99.634689  1183.214600  2512.063477  4374.311035\n15  2048.0  107.218948  1261.682739  2889.841797  7792.640137\n```\n\nThe Triton MAS implementation is at least 19 times faster and up to 72 times faster than the Cython implementation. 
PyTorch JIT implementations are faster than the Cython implementation for large tensors, especially version v1, which does not involve inter-device copying.\n\n| ms in linear scale | ms in log scale |\n|----------|----------|\n| ![Image 1](./assets/MAS.png) | ![Image 2](./assets/MAS_log.png) |\n\n## How to run benchmark\n```bash\ncd cython_monotonic_align; mkdir cython_monotonic_align; python setup.py build_ext --inplace\ncd ../super_monotonic_align; pip install -e ./\ncd ../; python test.py\n```\n\n# References\nThis implementation uses code from the following repositories:\n- [jaywalnut310's Official Glow-TTS Implementation](https://github.com/jaywalnut310/glow-tts)\n- [OpenAI's Triton-Lang Tutorials](https://github.com/triton-lang/triton)\n- [Tri Dao's FlashAttention (memory hierarchy)](https://github.com/Dao-AILab/flash-attention)\n\n# Acknowledgement\nThis work is supported by Supertone Inc. and HYBE Corp. \nWe thank Jinhyeok Yang, Juheon Lee, Yechan Yu, Seunghoon Ji, Jacob Morton, Seungu Han, Sungho Lee, Joon Byun, and Hoon Heo of the Supertone research team and Hyeong-Seok Choi of ElevenLabs.\n\n\n# Authors\n- Junhyeok Lee ([jlee843@jhu.edu](mailto:jlee843@jhu.edu))\n- Hyeongju Kim ([hyeongju@supertone.ai](mailto:hyeongju@supertone.ai))\n\nIf this repository is useful for your research, please consider citing (with Glow-TTS or VITS)!\n```bib\n@article{supermas,\n  title={{Super Monotonic Alignment Search}},\n  author={Lee, Junhyeok and Kim, Hyeongju},\n  journal={arXiv preprint arXiv:2409.07704},\n  year={2024}\n}\n```\n\nFeel free to create an issue if you encounter any problems or have any questions.\n\nAdditionally, [Supertone](https://supertone.ai) is hiring TTS researchers. \nIf you are interested, please check out our career opportunities!\n"
  },
  {
    "path": "assets/MAS.csv",
    "content": "T,Triton,JIT_v1,JIT_v2,Cython\n128.0,0.4,83.7,53.2,8.8\n256.0,1.6,155.4,104.6,43.5\n384.0,3.4,325.3,237.8,136.3\n512.0,5.8,440.0,344.7,305.0\n640.0,9.1,532.9,452.1,462.4\n768.0,12.2,656.0,587.2,488.3\n896.0,15.2,558.0,620.1,863.9\n1024.0,19.8,628.0,815.9,1299.6\n1152.0,33.3,706.0,968.5,1467.1\n1280.0,39.8,792.9,1215.0,1930.2\n1408.0,47.5,903.8,1289.7,2231.6\n1536.0,59.2,953.9,1523.9,2959.4\n1664.0,70.1,1031.8,2004.3,3073.5\n1792.0,82.2,1558.2,2359.3,3930.8\n1920.0,99.6,1183.2,2512.1,4374.3\n2048.0,107.2,1261.7,2889.8,7792.6\n"
  },
  {
    "path": "cython_monotonic_align/__init__.py",
    "content": "# modified from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/__init__.py\nimport numpy as np\nimport torch\nfrom .cython_monotonic_align.core import maximum_path_c\n\n\ndef maximum_path(value, mask):  \n  \"\"\" Cython optimised version.\n  value: [b, t_x, t_y]\n  mask: [b, t_x, t_y]\n  \"\"\"\n  value = value * mask\n  device = value.device\n  dtype = value.dtype\n  value = value.data.cpu().numpy().astype(np.float32)\n  path = np.zeros_like(value).astype(np.int32)\n  mask = mask.data.cpu().numpy()\n\n  t_x_max = mask.sum(1)[:, 0].astype(np.int32)\n  t_y_max = mask.sum(2)[:, 0].astype(np.int32)\n  maximum_path_c(path, value, t_x_max, t_y_max)\n  return torch.from_numpy(path).to(device=device, dtype=dtype)"
  },
  {
    "path": "cython_monotonic_align/core.pyx",
    "content": "# copied from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/core.pyx\nimport numpy as np\ncimport numpy as np\ncimport cython\nfrom cython.parallel import prange\n\n\n@cython.boundscheck(False)\n@cython.wraparound(False)\ncdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:\n  cdef int x\n  cdef int y\n  cdef float v_prev\n  cdef float v_cur\n  cdef float tmp\n  cdef int index = t_x - 1\n\n  for y in range(t_y):\n    for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):\n      if x == y:\n        v_cur = max_neg_val\n      else:\n        v_cur = value[x, y-1]\n      if x == 0:\n        if y == 0:\n          v_prev = 0.\n        else:\n          v_prev = max_neg_val\n      else:\n        v_prev = value[x-1, y-1]\n      value[x, y] = max(v_cur, v_prev) + value[x, y]\n\n  for y in range(t_y - 1, -1, -1):\n    path[index, y] = 1\n    if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):\n      index = index - 1\n\n\n@cython.boundscheck(False)\n@cython.wraparound(False)\ncpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e32) nogil:\n  cdef int b = values.shape[0]\n\n  cdef int i\n  for i in prange(b, nogil=True):\n    maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)\n"
  },
  {
    "path": "cython_monotonic_align/setup.py",
    "content": "# modified from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/setup.py\n\nfrom distutils.core import setup\nfrom Cython.Build import cythonize\nimport numpy\n\nsetup(\n  name = 'cython_monotonic_align',\n  ext_modules = cythonize(\"core.pyx\"),\n  include_dirs=[numpy.get_include()]\n)"
  },
  {
    "path": "jit_monotonic_align/__init__.py",
    "content": "import torch\n\n\n@torch.no_grad()\n@torch.jit.script\ndef maximum_path1(logp: torch.Tensor, attn_mask: torch.Tensor):\n    # logp: [B, Tx, Ty], attn_mask: [B, Tx, Ty]\n    B, Tx, Ty = logp.size()\n    device = logp.device\n    logp = logp * attn_mask  # [B, Tx, Ty]\n    path = torch.zeros_like(logp)  # [B, Tx, Ty]\n    max_neg_val = torch.tensor(-1e32, dtype=logp.dtype, device=device)\n\n    x_len = attn_mask[:, :, 0].sum(dim=1).long()  # [B]\n    y_len = attn_mask[:, 0, :].sum(dim=1).long()  # [B]\n\n    for b in range(B):\n        path[b, x_len[b] - 1, y_len[b] - 1] = 1\n\n    # logp to cumulative logp\n    logp[:, 1:, 0] = max_neg_val\n\n    for ty in range(1, Ty):\n        logp_prev_frame_1 = logp[:, :, ty - 1]  # [B, Tx]\n        logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1)  # [B, Tx]\n        logp_prev_frame_2[:, 0] = max_neg_val\n        logp_prev_frame_max = torch.where(logp_prev_frame_1 > logp_prev_frame_2, logp_prev_frame_1, logp_prev_frame_2)\n        logp[:, :, ty] += logp_prev_frame_max\n\n    ids = torch.ones_like(x_len, device=device) * (x_len - 1)  # [B]\n    arange = torch.arange(B, device=device)\n    path = path.permute(2, 0, 1).contiguous()  # [Ty, B, Tx]\n    attn_mask = attn_mask.permute(2, 0, 1).contiguous()  # [Ty, B, Tx]\n    y_len_minus_1 = y_len - 1  # [B]\n    for ty in range(Ty - 1, 0, -1):\n        logp_prev_frame_1 = logp[:, :, ty - 1]  # [B, Tx]\n        logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1)  # [B, Tx]\n        logp_prev_frame_2[:, 0] = max_neg_val\n        direction = torch.where(logp_prev_frame_1 > logp_prev_frame_2, 0, -1)  # [B, Tx]\n        gathered_dir = torch.gather(direction, 1, ids.view(-1, 1)).view(-1)  # [B]\n        gathered_dir.masked_fill_(ty > y_len_minus_1, 0)\n        ids.add_(gathered_dir)\n        path[ty - 1, arange, ids] = 1\n    path *= attn_mask\n    path = path.permute(1, 2, 0)  # [B, Tx, Ty]\n    return path\n\n\n@torch.no_grad()\ndef 
maximum_path2(logp: torch.Tensor, attn_mask: torch.Tensor):\n    @torch.jit.script\n    def cumulative_logp(logp, attn_mask):\n        B, Tx, Ty = logp.size()\n        device = logp.device\n        logp = logp * attn_mask  # [B, Tx, Ty]\n        path = torch.zeros_like(logp)  # [B, Tx, Ty]\n        max_neg_val = torch.tensor(-1e32, dtype=logp.dtype, device=device)\n\n        x_len = attn_mask[:, :, 0].sum(dim=1).long()  # [B]\n        y_len = attn_mask[:, 0, :].sum(dim=1).long()  # [B]\n\n        for b in range(B):\n            path[b, x_len[b] - 1, y_len[b] - 1] = 1\n\n        # logp to cumulative logp\n        logp[:, 1:, 0] = max_neg_val\n\n        for ty in range(1, Ty):\n            logp_prev_frame_1 = logp[:, :, ty - 1]  # [B, Tx]\n            logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1)  # [B, Tx]\n            logp_prev_frame_2[:, 0] = max_neg_val\n            logp_prev_frame_max = torch.where(\n                logp_prev_frame_1 > logp_prev_frame_2, logp_prev_frame_1, logp_prev_frame_2\n            )\n            logp[:, :, ty] += logp_prev_frame_max\n        return logp, x_len, y_len, path\n\n    device = logp.device\n    logp, x_len, y_len, path = cumulative_logp(logp, attn_mask)\n    B, Tx, Ty = logp.size()\n    logp = logp.detach().cpu().numpy()\n    x_len = x_len.detach().cpu().numpy()\n    y_len = y_len.detach().cpu().numpy()\n    path = path.detach().cpu().numpy()\n    # backtracking (naive)\n    for b in range(B):\n        idx = x_len[b] - 1\n        path[b, x_len[b] - 1, y_len[b] - 1] = 1\n        for ty in range(y_len[b] - 1, 0, -1):\n            if idx != 0 and logp[b, idx - 1, ty - 1] > logp[b, idx, ty - 1]:\n                idx = idx - 1\n            path[b, idx, ty - 1] = 1\n    path = torch.from_numpy(path).to(device)\n    return path\n\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n    name='super-monotonic-align',\n    version='1.0.0',\n    packages=find_packages(include=['super_monotonic_align', 'super_monotonic_align.*'])\n)\n"
  },
  {
    "path": "super_monotonic_align/__init__.py",
    "content": "import torch\nfrom super_monotonic_align.core import maximum_path_triton\n\n@torch.no_grad()\ndef maximum_path(value, mask, dtype=torch.float32):\n    \"\"\" Triton optimized version.\n    value: [b, t_x, t_y]\n    mask: [b, t_x, t_y]\n    skip_mask: [b, t_x]\n    \"\"\"\n    # check value is contiguous\n    value = value.contiguous()\n    # Use masked_fill_ to avoid new tensor creation\n    value = value.masked_fill_(mask.logical_not(), 0)\n    path = torch.zeros_like(value, dtype=dtype)\n    t_x_max = mask.sum(1)[:, 0].to(torch.int32) \n    t_y_max = mask.sum(2)[:, 0].to(torch.int32)\n    path = maximum_path_triton(path, value, t_x_max, t_y_max)\n    return path"
  },
  {
    "path": "super_monotonic_align/core.py",
    "content": "import torch\n\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef maximum_path(\n    path, value, t_x, t_y,\n    B, T, S,\n    max_neg_val,\n    BLOCK_SIZE_X: tl.constexpr\n    ):\n    batch = tl.program_id(axis=0)\n    path += batch * T * S\n    value += batch * T * S\n    x_length = tl.load(t_x + batch)\n    y_length = tl.load(t_y + batch)\n    offs_prev = tl.arange(0, BLOCK_SIZE_X)\n    init = tl.where(offs_prev ==0, tl.load(value), max_neg_val)\n    # for j in range(0,1,1):  # set the first column to max_neg_val without init point\n    tl.store(value + offs_prev * S, init, mask=offs_prev < x_length)\n    for j in range(1, y_length, 1):\n        v_cur= tl.load(value + (offs_prev) * S + (j-1), mask=(offs_prev < x_length), other=max_neg_val)\n        v_prev =tl.load(value + (offs_prev-1) * S + (j-1), mask=(0 < offs_prev) & (offs_prev < x_length), other=max_neg_val)\n        # compare v_cur and v_prev, and update v with larger value\n        v = (tl.maximum(v_cur, v_prev) + tl.load(value + (offs_prev) * S + j, mask=(offs_prev < x_length)))\n        tl.store(value + (offs_prev) * S + j, v, mask=(offs_prev < x_length))\n\n    index = x_length-1\n    for j in range(y_length-1,-1,-1):\n        tl.store(path + (index) * S + j, 1)\n        if (index > 0): # (index == j) is not checked due to max_neg_val init\n            v_left = tl.load(value+ (index) * S+ j-1)#.to(tl.float32)\n            v_leftdown =  tl.load(value+(index-1) * S + j-1)#.to(tl.float32)\n            if (v_left < v_leftdown):\n                index += - 1\n            \n                        \n@torch.no_grad()\ndef maximum_path_triton(path, value, t_x, t_y, max_neg_val=-1e32):\n    B,T,S = path.shape\n    BLOCK_SIZE_X = max(triton.next_power_of_2(T), 16)\n    num_warps = 1 # Need to be 1 to prevent wrong output by slicing the operation\n    with torch.cuda.device(value.device.index):\n        maximum_path[(B, )](\n            path, value, t_x, t_y, \n            B, T, S,\n 
           max_neg_val = max_neg_val,\n            num_warps = num_warps,\n            BLOCK_SIZE_X = BLOCK_SIZE_X)\n    return path\n\n"
  },
  {
    "path": "test.py",
    "content": "import torch\nimport triton\nfrom super_monotonic_align import maximum_path as maximum_path_trion\nfrom cython_monotonic_align import maximum_path as maximum_path_cython\nfrom jit_monotonic_align import maximum_path1 as maximum_path_jit_v1\nfrom jit_monotonic_align import maximum_path2 as maximum_path_jit_v2\n\n\ndef identical_test(B,T,S):\n    value = torch.randn((B, T, S), dtype=torch.float32, device='cuda')\n    attn_mask = torch.ones((B, T, S), dtype=torch.int32, device='cuda')\n    path_c = maximum_path_cython(value, attn_mask)\n    path_jit1 = maximum_path_jit_v1(value, attn_mask)\n    path_jit2 = maximum_path_jit_v2(value, attn_mask)\n    path_tri = maximum_path_trion(value.clone(), attn_mask)\n\n    # not 100% equal due to precision issue\n    assert torch.allclose(path_c, path_tri, atol=1e-2, rtol=0), f\"Failed on shape=({B,T,S})\\n{path_c}\\n{path_tri}\\ndiff:{(path_c-path_tri).abs().sum()}\"\n    assert torch.allclose(path_c, path_jit1, atol=1e-2, rtol=0), f\"Failed on shape=({B,T,S})\\n{path_c}\\n{path_jit1}\\ndiff:{(path_c-path_jit1).abs().sum()}\"\n    assert torch.allclose(path_c, path_jit2, atol=1e-2, rtol=0), f\"Failed on shape=({B,T,S})\\n{path_c}\\n{path_jit2}\\ndiff:{(path_c-path_jit2).abs().sum()}\"\n\n# benchmark\n@triton.testing.perf_report(\n    triton.testing.Benchmark(\n        x_names=['T'],\n        x_vals=[128 * i for i in range(1, 17)],\n        line_arg='provider',\n        line_vals= ['triton', 'jit_v1', 'jit_v2', 'cython'],\n        line_names=['Triton', 'JIT_v1', 'JIT_v2', 'Cython'],\n        styles=[('blue', '-'), ('green', '-'), ('red', '-'), ('orange', '-')],\n        ylabel='ms',\n        plot_name='MAS in ms',\n        y_log=True,\n        args={'B': 16},\n    ))\ndef bench_mas(B, T, provider, device='cuda'):\n    from cython_monotonic_align import maximum_path as maximum_path_cython\n    # create data\n    quantiles = [0.5, 0.2, 0.8]\n\n    S = 4*T\n    value = torch.randn((B, T, S), dtype=torch.float32, 
device=device)\n    attn_mask = torch.ones((B, T, S), dtype=torch.int32, device=device)\n \n    # utility functions\n    if provider == 'triton':\n\n        def y_fwd():\n            return maximum_path_trion(value, attn_mask)  # noqa: F811, E704\n\n    if provider == 'cython':\n\n        def y_fwd():\n            return maximum_path_cython(value, attn_mask)  # noqa: F811, E704\n    \n    if provider == 'jit_v1':\n            \n        def y_fwd():\n            return maximum_path_jit_v1(value, attn_mask)\n        \n    if provider == 'jit_v2':\n\n        def y_fwd():\n            return maximum_path_jit_v2(value, attn_mask)\n        \n    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)\n\n    return (ms), (max_ms), (min_ms)\n\nif __name__ == \"__main__\":\n    for (b,t,s) in [(32, 16, 16), (32, 128, 512), (32, 256, 1024), (32, 511, 2048)]:\n        identical_test(b,t,s)\n        print(f\"Passed on shape=({b},{t},{s})\")\n    bench_mas.run(save_path='.', print_data=True)\n"
  }
]