Repository: supertone-inc/super-monotonic-align
Branch: main
Commit: 9bb1cb3a6fba
Files: 11
Total size: 19.0 KB

Directory structure:
gitextract_pe14j_ce/
├── LICENSE
├── README.md
├── assets/
│   └── MAS.csv
├── cython_monotonic_align/
│   ├── __init__.py
│   ├── core.pyx
│   └── setup.py
├── jit_monotonic_align/
│   └── __init__.py
├── setup.py
├── super_monotonic_align/
│   ├── __init__.py
│   └── core.py
└── test.py


================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 Supertone Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Super-Monotonic-Alignment-Search

[![TechnicalReport](https://img.shields.io/badge/TechnicalReport-2409.07704-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2409.07704)

This repo contains [Triton-Lang](https://github.com/triton-lang/triton) and PyTorch implementations of the monotonic alignment search (MAS), originally from [Glow-TTS](https://arxiv.org/abs/2005.11129).
MAS is an effective algorithm for estimating the alignment between paired speech and text in a self-supervised manner.

![Image0](./assets/memory_read_write.png)

The authors of Glow-TTS noted:
> "The time complexity of the algorithm is O(T_{text} × T_{mel}). Even though the algorithm is difficult to parallelize, it runs efficiently on CPU without the need for GPU executions. In our experiments, it spends less than 20 ms on each iteration, which amounts to less than 2% of the total training time. Furthermore, we do not need MAS during inference, as the duration predictor is used to estimate the alignment."

However, we found three issues while using MAS.
1. MAS can be parallelized in the text-length dimension, while the original implementation uses nested loops.
2. CPU execution consumes an inordinate amount of time for large inputs due to the need to copy large tensors between CPU and GPU.
3. The hard-coded value of `max_neg_val` at -1e9 is insufficient to prevent alignment mismatches in the upper diagonal parts.

Therefore, we implemented a Triton kernel `super_monotonic_align` and PyTorch code `jit_monotonic_align` to accelerate MAS on GPU without inter-device copies.

# Requirements
1. PyTorch (tested with version `torch==2.3.0+cu121`)
2. Triton-Lang (tested with version `triton==2.3.0`)
3. Cython (optional, only needed for the benchmark; tested with version `Cython==0.29.36`)

Please ensure you have these packages installed to run the code in this repository, as version checks are not enforced.
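To make issues 1 and 3 concrete, here is a minimal pure-Python sketch of the recurrence for a single \[T, S\] score matrix (T text tokens, S frames). It is an illustration under stated assumptions, not part of the package: `mas_reference` is a hypothetical name, and it mirrors the column-by-column update of the Triton/JIT versions as read from their source, where unreachable cells are suppressed only by `max_neg_val`. The toy score of -2e9 is chosen solely to push an accumulated score below -1e9.

```python
from typing import List


def mas_reference(value: List[List[float]], t_x: int, t_y: int,
                  max_neg_val: float = -1e32) -> List[List[int]]:
    """Hypothetical pure-Python MAS for one [T][S] score matrix.

    Mirrors the column-wise recurrence of the Triton/JIT versions:
    unreachable cells are not skipped, only suppressed by max_neg_val.
    """
    T, S = len(value), len(value[0])
    v = [row[:] for row in value]  # copy; the Triton kernel overwrites `value` in place

    # First frame: only text token 0 may start the alignment.
    for x in range(1, t_x):
        v[x][0] = max_neg_val

    # Forward pass: column y reads only column y - 1, so every x in a column
    # is independent (the text-dimension parallelism from issue 1).
    for y in range(1, t_y):
        prev = [v[x][y - 1] for x in range(t_x)]
        for x in range(t_x):
            stay = prev[x]                                   # same token as previous frame
            advance = prev[x - 1] if x > 0 else max_neg_val  # moved on from token x - 1
            v[x][y] = max(stay, advance) + v[x][y]

    # Backtracking: start at the last token on the last valid frame and walk left,
    # stepping to the previous token when its score at the previous frame is higher.
    path = [[0] * S for _ in range(T)]
    index = t_x - 1
    for y in range(t_y - 1, -1, -1):
        path[index][y] = 1
        if index > 0 and y > 0 and v[index][y - 1] < v[index - 1][y - 1]:
            index -= 1
    return path


if __name__ == "__main__":
    # Clear-cut case: token 0 covers frames 0-1, token 1 covers frame 2.
    print(mas_reference([[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], 2, 3))  # [[1, 1, 0], [0, 0, 1]]

    # Issue 3 in miniature: once a genuine accumulated score falls below
    # max_neg_val, the blocked cell (token 1, frame 0) outscores the real
    # start (0, 0) and the path never visits token 0.
    scores = [[-2e9, 0.0], [0.0, 0.0]]
    print(mas_reference(scores, 2, 2, max_neg_val=-1e32))  # [[1, 0], [0, 1]]
    print(mas_reference(scores, 2, 2, max_neg_val=-1e9))   # [[0, 0], [1, 1]]
```

Only the forward pass is quadratic; backtracking is a single walk over the frames, which is why the GPU versions focus on parallelizing each column update.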
# How to use
1. Install super-monotonic-align
```
git clone git@github.com:supertone-inc/super-monotonic-align.git
cd super-monotonic-align; pip install -e ./
```
or
```
pip install git+https://github.com/supertone-inc/super-monotonic-align.git
```

2. Import `super_monotonic_align` and use it!
```python
import torch
from super_monotonic_align import maximum_path
...
# Note: the Triton kernel modifies `value` in place.
# If you need to keep `value` unchanged, clone it before calling maximum_path.

# B: batch_size, T: text_length, S: audio_length
value = torch.randn((B, T, S), dtype=torch.float32, device='cuda')
attn_mask = torch.ones((B, T, S), dtype=torch.int32, device='cuda')
# path: [B, T, S] tensor; its dtype can be set with `dtype` (default: torch.float32)
path = maximum_path(value, attn_mask, dtype=torch.bool)
```

## Warning
Please **check your input shape** before use.
Thanks to [codeghees](https://github.com/codeghees) for raising the issue: our implementation uses the shape \[B, T, S\], identical to the Glow-TTS version, while the VITS implementation uses the shape \[B, S, T\].
For now, we recommend transposing your input if it is shaped \[B, S, T\], but we will soon release an option that supports \[B, S, T\] as well.
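Reading the wrapper in `super_monotonic_align/__init__.py`, the lengths are inferred from the mask itself (`t_x` from `mask.sum(1)[:, 0]`, `t_y` from `mask.sum(2)[:, 0]`), so a \[B, S, T\] input may not raise any error; it silently swaps the two lengths. The pure-Python sketch below mirrors that inference for one batch element. `make_mask`, `lengths_from_mask` and `transpose` are hypothetical helpers written for illustration only.

```python
from typing import List, Tuple

Grid = List[List[int]]


def make_mask(t_x: int, t_y: int, T: int, S: int) -> Grid:
    """Hypothetical helper: [T][S] mask with ones in the top-left t_x-by-t_y block."""
    return [[int(x < t_x and y < t_y) for y in range(S)] for x in range(T)]


def lengths_from_mask(mask: Grid) -> Tuple[int, int]:
    """Mirror of the wrapper's length inference for one batch element:
    t_x = mask.sum(1)[:, 0] counts the ones in the first column,
    t_y = mask.sum(2)[:, 0] counts the ones in the first row."""
    t_x = sum(row[0] for row in mask)
    t_y = sum(mask[0])
    return t_x, t_y


def transpose(grid: Grid) -> Grid:
    """Swap the two axes, e.g. a VITS-style [S][T] slice -> [T][S]."""
    return [list(col) for col in zip(*grid)]


if __name__ == "__main__":
    mask = make_mask(t_x=3, t_y=5, T=4, S=6)          # 3 tokens, 5 frames, padded to 4 x 6
    print(lengths_from_mask(mask))                    # (3, 5)

    vits_mask = transpose(mask)                       # same mask in [S][T] layout
    print(lengths_from_mask(vits_mask))               # (5, 3): lengths silently swapped
    print(lengths_from_mask(transpose(vits_mask)))    # (3, 5) after transposing back
```

With PyTorch tensors, the corresponding fix would be `value.transpose(1, 2)` and `attn_mask.transpose(1, 2)` before the call, transposing the returned path back the same way. Because the wrapper calls `.contiguous()` first, a transposed (non-contiguous) view is copied, so in that case the caller's original tensor is not overwritten.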
# Benchmark
```
MAS in ms:
         T      Triton       JIT_v1       JIT_v2       Cython
0    128.0    0.447488    83.742203    53.222176     8.819136
1    256.0    1.616896   155.424774   104.632477    43.533665
2    384.0    3.430400   325.307404   237.820435   136.257538
3    512.0    5.838848   439.984131   344.654236   304.981201
4    640.0    9.070592   532.910095   452.141907   462.405304
5    768.0   12.249088   655.960083   587.169739   488.272858
6    896.0   15.203328   557.997070   620.148315   863.919067
7   1024.0   19.778561   627.986450   815.933167  1299.567871
8   1152.0   33.276928   706.022400   968.533813  1467.056885
9   1280.0   39.800835   792.861694  1215.021240  1930.171509
10  1408.0   47.456257   903.750671  1289.656250  2231.598145
11  1536.0   59.238914   953.907227  1523.870972  2959.377930
12  1664.0   70.068741  1031.818237  2004.299438  3073.532471
13  1792.0   82.205696  1558.200317  2359.347900  3930.776367
14  1920.0   99.634689  1183.214600  2512.063477  4374.311035
15  2048.0  107.218948  1261.682739  2889.841797  7792.640137
```
The benchmark uses batch size B=16 and audio length S=4T (see `test.py`).
The Triton MAS implementation is 19 to 72 times faster than the Cython implementation across the tested sizes.
The PyTorch JIT implementations are faster than the Cython implementation for large inputs (from T=896 onward in this run), especially version v1, which does not involve inter-device copying.

| ms in linear scale | ms in log scale |
|----------|----------|
| ![Image 1](./assets/MAS.png) | ![Image 2](./assets/MAS_log.png) |

## How to run benchmark
```bash
cd cython_monotonic_align; mkdir cython_monotonic_align; python setup.py build_ext --inplace
cd ..; pip install -e ./
python test.py
```

# References
This implementation uses code from the following repositories:
- [jaywalnut310's Official Glow-TTS Implementation](https://github.com/jaywalnut310/glow-tts)
- [OpenAI's Triton-Lang Tutorials](https://github.com/triton-lang/triton)
- [Tri Dao's FlashAttention (memory hierarchy)](https://github.com/Dao-AILab/flash-attention)

# Acknowledgement
This work is supported by Supertone Inc. and HYBE Corp.
We thank Jinhyeok Yang, Juheon Lee, Yechan Yu, Seunghoon Ji, Jacob Morton, Seungu Han, Sungho Lee, Joon Byun, and Hoon Heo of the Supertone research team and Hyeong-Seok Choi of ElevenLabs.

# Authors
- Junhyeok Lee ([jlee843@jhu.edu](mailto:jlee843@jhu.edu))
- Hyoungju Kim ([hyeongju@supertone.ai](mailto:hyeongju@supertone.ai))

If this repository is useful for your research, please consider citing it (along with Glow-TTS or VITS)!
```bib
@article{supermas,
  title={{Super Monotonic Alignment Search}},
  author={Lee, Junhyeok and Kim, Hyeongju},
  journal={arXiv preprint arXiv:2409.07704},
  year={2024}
}
```
Feel free to create an issue if you encounter any problems or have any questions.

Additionally, [Supertone](https://supertone.ai) is hiring TTS researchers. If you are interested, please check out our career opportunities!


================================================
FILE: assets/MAS.csv
================================================
T,Triton,JIT_v1,JIT_v2,Cython
128.0,0.4,83.7,53.2,8.8
256.0,1.6,155.4,104.6,43.5
384.0,3.4,325.3,237.8,136.3
512.0,5.8,440.0,344.7,305.0
640.0,9.1,532.9,452.1,462.4
768.0,12.2,656.0,587.2,488.3
896.0,15.2,558.0,620.1,863.9
1024.0,19.8,628.0,815.9,1299.6
1152.0,33.3,706.0,968.5,1467.1
1280.0,39.8,792.9,1215.0,1930.2
1408.0,47.5,903.8,1289.7,2231.6
1536.0,59.2,953.9,1523.9,2959.4
1664.0,70.1,1031.8,2004.3,3073.5
1792.0,82.2,1558.2,2359.3,3930.8
1920.0,99.6,1183.2,2512.1,4374.3
2048.0,107.2,1261.7,2889.8,7792.6


================================================
FILE: cython_monotonic_align/__init__.py
================================================
# modified from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/__init__.py
import numpy as np
import torch
from .cython_monotonic_align.core import maximum_path_c


def maximum_path(value, mask):
    """ Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)


================================================
FILE: cython_monotonic_align/core.pyx
================================================
# copied from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/core.pyx
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
  cdef int x
  cdef int y
  cdef float v_prev
  cdef float v_cur
  cdef float tmp
  cdef int index = t_x - 1

  for y in range(t_y):
    for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
      if x == y:
        v_cur = max_neg_val
      else:
        v_cur = value[x, y-1]
      if x == 0:
        if y == 0:
          v_prev = 0.
        else:
          v_prev = max_neg_val
      else:
        v_prev = value[x-1, y-1]
      value[x, y] = max(v_cur, v_prev) + value[x, y]

  for y in range(t_y - 1, -1, -1):
    path[index, y] = 1
    if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
      index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e32) nogil:
  cdef int b = values.shape[0]

  cdef int i
  for i in prange(b, nogil=True):
    maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)


================================================
FILE: cython_monotonic_align/setup.py
================================================
# modified from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/setup.py
from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
    name = 'cython_monotonic_align',
    ext_modules = cythonize("core.pyx"),
    include_dirs=[numpy.get_include()]
)


================================================
FILE: jit_monotonic_align/__init__.py
================================================
import torch


@torch.no_grad()
@torch.jit.script
def maximum_path1(logp: torch.Tensor, attn_mask: torch.Tensor):
    # logp: [B, Tx, Ty], attn_mask: [B, Tx, Ty]
    B, Tx, Ty = logp.size()
    device = logp.device
    logp = logp * attn_mask  # [B, Tx, Ty]
    path = torch.zeros_like(logp)  # [B, Tx, Ty]
    max_neg_val = torch.tensor(-1e32, dtype=logp.dtype, device=device)

    x_len = attn_mask[:, :, 0].sum(dim=1).long()  # [B]
    y_len = attn_mask[:, 0, :].sum(dim=1).long()  # [B]

    for b in range(B):
        path[b, x_len[b] - 1, y_len[b] - 1] = 1

    # logp to cumulative logp
    logp[:, 1:, 0] = max_neg_val
    for ty in range(1, Ty):
        logp_prev_frame_1 = logp[:, :, ty - 1]  # [B, Tx]
        logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1)  # [B, Tx]
        logp_prev_frame_2[:, 0] = max_neg_val
        logp_prev_frame_max = torch.where(logp_prev_frame_1 > logp_prev_frame_2, logp_prev_frame_1,
                                          logp_prev_frame_2)
        logp[:, :, ty] += logp_prev_frame_max

    ids = torch.ones_like(x_len, device=device) * (x_len - 1)  # [B]
    arange = torch.arange(B, device=device)
    path = path.permute(2, 0, 1).contiguous()  # [Ty, B, Tx]
    attn_mask = attn_mask.permute(2, 0, 1).contiguous()  # [Ty, B, Tx]
    y_len_minus_1 = y_len - 1  # [B]
    for ty in range(Ty - 1, 0, -1):
        logp_prev_frame_1 = logp[:, :, ty - 1]  # [B, Tx]
        logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1)  # [B, Tx]
        logp_prev_frame_2[:, 0] = max_neg_val
        direction = torch.where(logp_prev_frame_1 > logp_prev_frame_2, 0, -1)  # [B, Tx]
        gathered_dir = torch.gather(direction, 1, ids.view(-1, 1)).view(-1)  # [B]
        gathered_dir.masked_fill_(ty > y_len_minus_1, 0)
        ids.add_(gathered_dir)
        path[ty - 1, arange, ids] = 1
    path *= attn_mask
    path = path.permute(1, 2, 0)  # [B, Tx, Ty]
    return path


@torch.no_grad()
def maximum_path2(logp: torch.Tensor, attn_mask: torch.Tensor):
    @torch.jit.script
    def cumulative_logp(logp, attn_mask):
        B, Tx, Ty = logp.size()
        device = logp.device
        logp = logp * attn_mask  # [B, Tx, Ty]
        path = torch.zeros_like(logp)  # [B, Tx, Ty]
        max_neg_val = torch.tensor(-1e32, dtype=logp.dtype, device=device)

        x_len = attn_mask[:, :, 0].sum(dim=1).long()  # [B]
        y_len = attn_mask[:, 0, :].sum(dim=1).long()  # [B]

        for b in range(B):
            path[b, x_len[b] - 1, y_len[b] - 1] = 1

        # logp to cumulative logp
        logp[:, 1:, 0] = max_neg_val
        for ty in range(1, Ty):
            logp_prev_frame_1 = logp[:, :, ty - 1]  # [B, Tx]
            logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1)  # [B, Tx]
            logp_prev_frame_2[:, 0] = max_neg_val
            logp_prev_frame_max = torch.where(
                logp_prev_frame_1 > logp_prev_frame_2, logp_prev_frame_1, logp_prev_frame_2
            )
            logp[:, :, ty] += logp_prev_frame_max
        return logp, x_len, y_len, path

    device = logp.device
    logp, x_len, y_len, path = cumulative_logp(logp, attn_mask)
    B, Tx, Ty = logp.size()

    logp = logp.detach().cpu().numpy()
    x_len = x_len.detach().cpu().numpy()
    y_len = y_len.detach().cpu().numpy()
    path = path.detach().cpu().numpy()

    # backtracking (naive)
    for b in range(B):
        idx = x_len[b] - 1
        path[b, x_len[b] - 1, y_len[b] - 1] = 1
        for ty in range(y_len[b] - 1, 0, -1):
            if idx != 0 and logp[b, idx - 1, ty - 1] > logp[b, idx, ty - 1]:
                idx = idx - 1
            path[b, idx, ty - 1] = 1

    path = torch.from_numpy(path).to(device)
    return path


================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(
    name='super-monotonic-align',
    version='1.0.0',
    packages=find_packages(include=['super_monotonic_align', 'super_monotonic_align.*'])
)


================================================
FILE: super_monotonic_align/__init__.py
================================================
import torch
from super_monotonic_align.core import maximum_path_triton


@torch.no_grad()
def maximum_path(value, mask, dtype=torch.float32):
    """ Triton optimized version.
    value: [b, t_x, t_y] (overwritten in place if it is already contiguous)
    mask: [b, t_x, t_y]
    dtype: dtype of the returned path
    """
    # ensure value is contiguous (returns the same tensor if it already is)
    value = value.contiguous()
    # Use masked_fill_ to avoid new tensor creation
    value = value.masked_fill_(mask.logical_not(), 0)
    path = torch.zeros_like(value, dtype=dtype)
    t_x_max = mask.sum(1)[:, 0].to(torch.int32)
    t_y_max = mask.sum(2)[:, 0].to(torch.int32)
    path = maximum_path_triton(path, value, t_x_max, t_y_max)
    return path


================================================
FILE: super_monotonic_align/core.py
================================================
import torch
import triton
import triton.language as tl


@triton.jit
def maximum_path(
    path, value, t_x, t_y,
    B, T, S,
    max_neg_val,
    BLOCK_SIZE_X: tl.constexpr
):
    batch = tl.program_id(axis=0)
    path += batch * T * S
    value += batch * T * S
    x_length = tl.load(t_x + batch)
    y_length = tl.load(t_y + batch)
    offs_prev = tl.arange(0, BLOCK_SIZE_X)
    init = tl.where(offs_prev == 0, tl.load(value), max_neg_val)
    # Set the first column to max_neg_val, except for the starting point (0, 0).
    tl.store(value + offs_prev * S,
             init, mask=offs_prev < x_length)
    for j in range(1, y_length, 1):
        v_cur = tl.load(value + (offs_prev) * S + (j - 1), mask=(offs_prev < x_length), other=max_neg_val)
        v_prev = tl.load(value + (offs_prev - 1) * S + (j - 1), mask=(0 < offs_prev) & (offs_prev < x_length), other=max_neg_val)
        # compare v_cur and v_prev, and update v with the larger value
        v = (tl.maximum(v_cur, v_prev) + tl.load(value + (offs_prev) * S + j, mask=(offs_prev < x_length)))
        tl.store(value + (offs_prev) * S + j, v, mask=(offs_prev < x_length))

    index = x_length - 1
    for j in range(y_length - 1, -1, -1):
        tl.store(path + (index) * S + j, 1)
        if (index > 0):  # (index == j) is not checked; relies on the max_neg_val initialization
            v_left = tl.load(value + (index) * S + j - 1)
            v_leftdown = tl.load(value + (index - 1) * S + j - 1)
            if (v_left < v_leftdown):
                index -= 1


@torch.no_grad()
def maximum_path_triton(path, value, t_x, t_y, max_neg_val=-1e32):
    B, T, S = path.shape
    BLOCK_SIZE_X = max(triton.next_power_of_2(T), 16)
    num_warps = 1  # Must be 1: slicing the operation across multiple warps gives wrong output
    with torch.cuda.device(value.device.index):
        maximum_path[(B, )](
            path, value, t_x, t_y,
            B, T, S,
            max_neg_val=max_neg_val,
            num_warps=num_warps,
            BLOCK_SIZE_X=BLOCK_SIZE_X)
    return path


================================================
FILE: test.py
================================================
import torch
import triton
from super_monotonic_align import maximum_path as maximum_path_triton
from cython_monotonic_align import maximum_path as maximum_path_cython
from jit_monotonic_align import maximum_path1 as maximum_path_jit_v1
from jit_monotonic_align import maximum_path2 as maximum_path_jit_v2


def identical_test(B, T, S):
    value = torch.randn((B, T, S), dtype=torch.float32, device='cuda')
    attn_mask = torch.ones((B, T, S), dtype=torch.int32, device='cuda')
    path_c = maximum_path_cython(value, attn_mask)
    path_jit1 = maximum_path_jit_v1(value, attn_mask)
    path_jit2 = maximum_path_jit_v2(value, attn_mask)
    # clone: the Triton version overwrites `value` in place
    path_tri = maximum_path_triton(value.clone(), attn_mask)

    # not 100% equal due to precision issue
    assert torch.allclose(path_c, path_tri, atol=1e-2, rtol=0), f"Failed on shape=({B},{T},{S})\n{path_c}\n{path_tri}\ndiff:{(path_c-path_tri).abs().sum()}"
    assert torch.allclose(path_c, path_jit1, atol=1e-2, rtol=0), f"Failed on shape=({B},{T},{S})\n{path_c}\n{path_jit1}\ndiff:{(path_c-path_jit1).abs().sum()}"
    assert torch.allclose(path_c, path_jit2, atol=1e-2, rtol=0), f"Failed on shape=({B},{T},{S})\n{path_c}\n{path_jit2}\ndiff:{(path_c-path_jit2).abs().sum()}"


# benchmark
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['T'],
        x_vals=[128 * i for i in range(1, 17)],
        line_arg='provider',
        line_vals=['triton', 'jit_v1', 'jit_v2', 'cython'],
        line_names=['Triton', 'JIT_v1', 'JIT_v2', 'Cython'],
        styles=[('blue', '-'), ('green', '-'), ('red', '-'), ('orange', '-')],
        ylabel='ms',
        plot_name='MAS in ms',
        y_log=True,
        args={'B': 16},
    ))
def bench_mas(B, T, provider, device='cuda'):
    from cython_monotonic_align import maximum_path as maximum_path_cython
    # create data
    quantiles = [0.5, 0.2, 0.8]
    S = 4 * T
    value = torch.randn((B, T, S), dtype=torch.float32, device=device)
    attn_mask = torch.ones((B, T, S), dtype=torch.int32, device=device)

    # utility functions
    if provider == 'triton':
        def y_fwd(): return maximum_path_triton(value, attn_mask)  # noqa: F811, E704
    if provider == 'cython':
        def y_fwd(): return maximum_path_cython(value, attn_mask)  # noqa: F811, E704
    if provider == 'jit_v1':
        def y_fwd(): return maximum_path_jit_v1(value, attn_mask)  # noqa: F811, E704
    if provider == 'jit_v2':
        def y_fwd(): return maximum_path_jit_v2(value, attn_mask)  # noqa: F811, E704

    # quantiles [0.5, 0.2, 0.8] -> median, 20th and 80th percentile
    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
    return ms, min_ms, max_ms


if __name__ == "__main__":
    for (b, t, s) in [(32, 16, 16), (32, 128, 512), (32, 256, 1024), (32, 511, 2048)]:
        identical_test(b, t, s)
        print(f"Passed on shape=({b},{t},{s})")

    bench_mas.run(save_path='.', print_data=True)